From d3fe02e46772eb147529e8cf3398f7d81fd1f2e3 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 14 Mar 2025 15:38:33 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=A2=9E=E5=A4=A7=E4=BA=86=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E7=9A=84maxtoken=E9=98=B2=E6=AD=A2=E6=BA=A2=E5=87=BA?= =?UTF-8?q?=EF=BC=8Cmessagecq=E6=94=B9=E5=BC=82=E6=AD=A5get=5Fimage?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E9=98=BB=E5=A1=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/bot.py | 11 +- src/plugins/chat/cq_code.py | 215 +++++++++++------------------ src/plugins/chat/emoji_manager.py | 5 +- src/plugins/chat/llm_generator.py | 60 +++----- src/plugins/chat/message_cq.py | 23 +-- src/plugins/chat/prompt_builder.py | 177 ++++++++++++------------ src/plugins/chat/utils_image.py | 2 +- 7 files changed, 207 insertions(+), 286 deletions(-) diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index b8624dae0..65aa3702d 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -74,6 +74,7 @@ class ChatBot: reply_message=None, platform="qq", ) + await message_cq.initialize() message_json = message_cq.to_dict() # 进入maimbot @@ -120,8 +121,13 @@ class ChatBot: # 用户屏蔽,不区分私聊/群聊 if event.user_id in global_config.ban_user_id: return - - if event.reply and hasattr(event.reply, 'sender') and hasattr(event.reply.sender, 'user_id') and event.reply.sender.user_id in global_config.ban_user_id: + + if ( + event.reply + and hasattr(event.reply, "sender") + and hasattr(event.reply.sender, "user_id") + and event.reply.sender.user_id in global_config.ban_user_id + ): logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息") return # 处理私聊消息 @@ -171,6 +177,7 @@ class ChatBot: reply_message=event.reply, platform="qq", ) + await message_cq.initialize() message_json = message_cq.to_dict() # 进入maimbot diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 049419f1c..2edc011b2 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -1,48 +1,28 @@ import base64 import html import time +import asyncio from dataclasses import dataclass from typing import Dict, List, Optional, Union - +import ssl import os - -import requests - -# 解析各种CQ码 -# 包含CQ码类 -import urllib3 +import aiohttp from loguru import logger from nonebot import get_driver -from urllib3.util import create_urllib3_context from ..models.utils_model import LLM_request from .config import global_config from .mapper import emojimapper from .message_base import Seg -from .utils_user import get_user_nickname,get_groupname +from .utils_user import get_user_nickname, get_groupname from .message_base import GroupInfo, UserInfo driver = get_driver() config = driver.config -# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616 -ctx = create_urllib3_context() -ctx.load_default_certs() -ctx.set_ciphers("AES128-GCM-SHA256") - - -class TencentSSLAdapter(requests.adapters.HTTPAdapter): - def __init__(self, ssl_context=None, **kwargs): - self.ssl_context = ssl_context - super().__init__(**kwargs) - - def init_poolmanager(self, connections, maxsize, block=False): - self.poolmanager = urllib3.poolmanager.PoolManager( - num_pools=connections, - maxsize=maxsize, - block=block, - ssl_context=self.ssl_context, - ) +# 创建SSL上下文 +ssl_context = ssl.create_default_context() +ssl_context.set_ciphers("AES128-GCM-SHA256") @dataclass @@ -70,14 +50,12 @@ class CQCode: """初始化LLM实例""" pass - def translate(self): + async def translate(self): """根据CQ码类型进行相应的翻译处理,转换为Seg对象""" if self.type == "text": - self.translated_segments = Seg( - type="text", data=self.params.get("text", "") - ) + self.translated_segments = Seg(type="text", data=self.params.get("text", "")) elif self.type == "image": - base64_data = self.translate_image() + base64_data = await self.translate_image() if base64_data: if self.params.get("sub_type") == "0": self.translated_segments = Seg(type="image", data=base64_data) @@ -88,24 +66,20 @@ class CQCode: elif self.type == "at": if self.params.get("qq") == "all": self.translated_segments = Seg(type="text", data="@[全体成员]") - else: + else: user_nickname = get_user_nickname(self.params.get("qq", "")) - self.translated_segments = Seg( - type="text", data=f"[@{user_nickname or '某人'}]" - ) + self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]") elif self.type == "reply": - reply_segments = self.translate_reply() + reply_segments = await self.translate_reply() if reply_segments: self.translated_segments = Seg(type="seglist", data=reply_segments) else: self.translated_segments = Seg(type="text", data="[回复某人消息]") elif self.type == "face": face_id = self.params.get("id", "") - self.translated_segments = Seg( - type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]" - ) + self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]") elif self.type == "forward": - forward_segments = self.translate_forward() + forward_segments = await self.translate_forward() if forward_segments: self.translated_segments = Seg(type="seglist", data=forward_segments) else: @@ -113,18 +87,8 @@ class CQCode: else: self.translated_segments = Seg(type="text", data=f"[{self.type}]") - def get_img(self): - """ - headers = { - 'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0', - 'Accept': 'image/*;q=0.8', - 'Accept-Encoding': 'gzip, deflate, br', - 'Connection': 'keep-alive', - 'Cache-Control': 'no-cache', - 'Pragma': 'no-cache' - } - """ - # 腾讯专用请求头配置 + async def get_img(self) -> Optional[str]: + """异步获取图片并转换为base64""" headers = { "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36", "Accept": "text/html, application/xhtml xml, */*", @@ -133,61 +97,63 @@ class CQCode: "Content-Type": "application/x-www-form-urlencoded", "Cache-Control": "no-cache", } + url = html.unescape(self.params["url"]) if not url.startswith(("http://", "https://")): return None - # 创建专用会话 - session = requests.session() - session.adapters.pop("https://", None) - session.mount("https://", TencentSSLAdapter(ctx)) - max_retries = 3 for retry in range(max_retries): try: - response = session.get( - url, - headers=headers, - timeout=15, - allow_redirects=True, - stream=True, # 流式传输避免大内存问题 - ) + logger.debug(f"获取图片中: {url}") + # 设置SSL上下文和创建连接器 + conn = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=conn) as session: + async with session.get( + url, + headers=headers, + timeout=aiohttp.ClientTimeout(total=15), + allow_redirects=True, + ) as response: + # 腾讯服务器特殊状态码处理 + if response.status == 400 and "multimedia.nt.qq.com.cn" in url: + return None - # 腾讯服务器特殊状态码处理 - if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url: - return None + if response.status != 200: + raise aiohttp.ClientError(f"HTTP {response.status}") - if response.status_code != 200: - raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") + # 验证内容类型 + content_type = response.headers.get("Content-Type", "") + if not content_type.startswith("image/"): + raise ValueError(f"非图片内容类型: {content_type}") - # 验证内容类型 - content_type = response.headers.get("Content-Type", "") - if not content_type.startswith("image/"): - raise ValueError(f"非图片内容类型: {content_type}") + # 读取响应内容 + content = await response.read() + logger.debug(f"获取图片成功: {url}") - # 转换为Base64 - image_base64 = base64.b64encode(response.content).decode("utf-8") - self.image_base64 = image_base64 - return image_base64 + # 转换为Base64 + image_base64 = base64.b64encode(content).decode("utf-8") + self.image_base64 = image_base64 + return image_base64 - except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: + except (aiohttp.ClientError, ValueError) as e: if retry == max_retries - 1: logger.error(f"最终请求失败: {str(e)}") - time.sleep(1.5**retry) # 指数退避 + await asyncio.sleep(1.5**retry) # 指数退避 - except Exception: - logger.exception("[未知错误]") + except Exception as e: + logger.exception(f"获取图片时发生未知错误: {str(e)}") return None return None - def translate_image(self) -> Optional[str]: + async def translate_image(self) -> Optional[str]: """处理图片类型的CQ码,返回base64字符串""" if "url" not in self.params: return None - return self.get_img() + return await self.get_img() - def translate_forward(self) -> Optional[List[Seg]]: + async def translate_forward(self) -> Optional[List[Seg]]: """处理转发消息,返回Seg列表""" try: if "content" not in self.params: @@ -217,15 +183,16 @@ class CQCode: else: if raw_message: from .message_cq import MessageRecvCQ - user_info=UserInfo( - platform='qq', + + user_info = UserInfo( + platform="qq", user_id=msg.get("user_id", 0), user_nickname=nickname, ) - group_info=GroupInfo( - platform='qq', + group_info = GroupInfo( + platform="qq", group_id=msg.get("group_id", 0), - group_name=get_groupname(msg.get("group_id", 0)) + group_name=get_groupname(msg.get("group_id", 0)), ) message_obj = MessageRecvCQ( @@ -235,24 +202,23 @@ class CQCode: plain_text=raw_message, group_info=group_info, ) - content_seg = Seg( - type="seglist", data=[message_obj.message_segment] - ) + await message_obj.initialize() + content_seg = Seg(type="seglist", data=[message_obj.message_segment]) else: content_seg = Seg(type="text", data="[空消息]") else: if raw_message: from .message_cq import MessageRecvCQ - user_info=UserInfo( - platform='qq', + user_info = UserInfo( + platform="qq", user_id=msg.get("user_id", 0), user_nickname=nickname, ) - group_info=GroupInfo( - platform='qq', + group_info = GroupInfo( + platform="qq", group_id=msg.get("group_id", 0), - group_name=get_groupname(msg.get("group_id", 0)) + group_name=get_groupname(msg.get("group_id", 0)), ) message_obj = MessageRecvCQ( message_id=msg.get("message_id", 0), @@ -261,9 +227,8 @@ class CQCode: plain_text=raw_message, group_info=group_info, ) - content_seg = Seg( - type="seglist", data=[message_obj.message_segment] - ) + await message_obj.initialize() + content_seg = Seg(type="seglist", data=[message_obj.message_segment]) else: content_seg = Seg(type="text", data="[空消息]") @@ -277,7 +242,7 @@ class CQCode: logger.error(f"处理转发消息失败: {str(e)}") return None - def translate_reply(self) -> Optional[List[Seg]]: + async def translate_reply(self) -> Optional[List[Seg]]: """处理回复类型的CQ码,返回Seg列表""" from .message_cq import MessageRecvCQ @@ -285,22 +250,19 @@ class CQCode: return None if self.reply_message.sender.user_id: - message_obj = MessageRecvCQ( - user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname), + user_info=UserInfo( + user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname + ), message_id=self.reply_message.message_id, raw_message=str(self.reply_message.message), group_info=GroupInfo(group_id=self.reply_message.group_id), ) - + await message_obj.initialize() segments = [] if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: - segments.append( - Seg( - type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: " - ) - ) + segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: ")) else: segments.append( Seg( @@ -318,16 +280,12 @@ class CQCode: @staticmethod def unescape(text: str) -> str: """反转义CQ码中的特殊字符""" - return ( - text.replace(",", ",") - .replace("[", "[") - .replace("]", "]") - .replace("&", "&") - ) + return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&") + class CQCode_tool: @staticmethod - def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode: + def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode: """ 将CQ码字典转换为CQCode对象 @@ -353,11 +311,9 @@ class CQCode_tool: params=params, group_info=msg.message_info.group_info, user_info=msg.message_info.user_info, - reply_message=reply + reply_message=reply, ) - # 进行翻译处理 - instance.translate() return instance @staticmethod @@ -383,12 +339,7 @@ class CQCode_tool: # 确保使用绝对路径 abs_path = os.path.abspath(file_path) # 转义特殊字符 - escaped_path = ( - abs_path.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") - ) + escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" @@ -403,14 +354,11 @@ class CQCode_tool: """ # 转义base64数据 escaped_base64 = ( - base64_data.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") + base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") ) # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" - + @staticmethod def create_image_cq_base64(base64_data: str) -> str: """ @@ -422,10 +370,7 @@ class CQCode_tool: """ # 转义base64数据 escaped_base64 = ( - base64_data.replace("&", "&") - .replace("[", "[") - .replace("]", "]") - .replace(",", ",") + base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") ) # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]" diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index e3342d1a7..4ac1af73e 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -37,7 +37,7 @@ class EmojiManager: self._scan_task = None self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.llm_emotion_judge = LLM_request( - model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8 + model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8 ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) def _ensure_emoji_dir(self): @@ -275,9 +275,6 @@ class EmojiManager: continue logger.info(f"check通过 {check}") - if description is not None: - embedding = await get_embedding(description) - if description is not None: embedding = await get_embedding(description) diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 2e0c0eb1f..a76f98dfb 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -25,30 +25,19 @@ class ResponseGenerator: max_tokens=1000, stream=True, ) - self.model_v3 = LLM_request( - model=global_config.llm_normal, temperature=0.7, max_tokens=1000 - ) - self.model_r1_distill = LLM_request( - model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000 - ) - self.model_v25 = LLM_request( - model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000 - ) + self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7, max_tokens=3000) + self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=3000) + self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7, max_tokens=3000) self.current_model_type = "r1" # 默认使用 R1 - async def generate_response( - self, message: MessageThinking - ) -> Optional[Union[str, List[str]]]: + async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]: """根据当前模型类型选择对应的生成函数""" # 从global_config中获取模型概率值并选择模型 rand = random.random() if rand < global_config.MODEL_R1_PROBABILITY: self.current_model_type = "r1" current_model = self.model_r1 - elif ( - rand - < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY - ): + elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY: self.current_model_type = "v3" current_model = self.model_v3 else: @@ -57,37 +46,28 @@ class ResponseGenerator: logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中") - model_response = await self._generate_response_with_model( - message, current_model - ) + model_response = await self._generate_response_with_model(message, current_model) raw_content = model_response # print(f"raw_content: {raw_content}") # print(f"model_response: {model_response}") - + if model_response: - logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') + logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}") model_response = await self._process_response(model_response) if model_response: return model_response, raw_content return None, raw_content - async def _generate_response_with_model( - self, message: MessageThinking, model: LLM_request - ) -> Optional[str]: + async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request) -> Optional[str]: """使用指定的模型生成回复""" - sender_name = ( - message.chat_stream.user_info.user_nickname - or f"用户{message.chat_stream.user_info.user_id}" - ) + sender_name = message.chat_stream.user_info.user_nickname or f"用户{message.chat_stream.user_info.user_id}" if message.chat_stream.user_info.user_cardname: sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}" # 获取关系值 relationship_value = ( - relationship_manager.get_relationship( - message.chat_stream - ).relationship_value + relationship_manager.get_relationship(message.chat_stream).relationship_value if relationship_manager.get_relationship(message.chat_stream) else 0.0 ) @@ -202,7 +182,7 @@ class ResponseGenerator: return None, [] processed_response = process_llm_response(content) - + # print(f"得到了处理后的llm返回{processed_response}") return processed_response @@ -212,13 +192,11 @@ class InitiativeMessageGenerate: def __init__(self): self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7) - self.model_r1_distill = LLM_request( - model=global_config.llm_reasoning_minor, temperature=0.7 - ) + self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7) def gen_response(self, message: Message): - topic_select_prompt, dots_for_select, prompt_template = ( - prompt_builder._build_initiative_prompt_select(message.group_id) + topic_select_prompt, dots_for_select, prompt_template = prompt_builder._build_initiative_prompt_select( + message.group_id ) content_select, reasoning = self.model_v3.generate_response(topic_select_prompt) logger.debug(f"{content_select} {reasoning}") @@ -230,16 +208,12 @@ class InitiativeMessageGenerate: return None else: return None - prompt_check, memory = prompt_builder._build_initiative_prompt_check( - select_dot[1], prompt_template - ) + prompt_check, memory = prompt_builder._build_initiative_prompt_check(select_dot[1], prompt_template) content_check, reasoning_check = self.model_v3.generate_response(prompt_check) logger.info(f"{content_check} {reasoning_check}") if "yes" not in content_check.lower(): return None - prompt = prompt_builder._build_initiative_prompt( - select_dot, prompt_template, memory - ) + prompt = prompt_builder._build_initiative_prompt(select_dot, prompt_template, memory) content, reasoning = self.model_r1.generate_response_async(prompt) logger.debug(f"[DEBUG] {content} {reasoning}") return content diff --git a/src/plugins/chat/message_cq.py b/src/plugins/chat/message_cq.py index 4c46d3bf2..435bdf19e 100644 --- a/src/plugins/chat/message_cq.py +++ b/src/plugins/chat/message_cq.py @@ -57,16 +57,20 @@ class MessageRecvCQ(MessageCQ): # 私聊消息不携带group_info if group_info is None: pass - elif group_info.group_name is None: group_info.group_name = get_groupname(group_info.group_id) # 解析消息段 - self.message_segment = self._parse_message(raw_message, reply_message) + self.message_segment = None # 初始化为None self.raw_message = raw_message + # 异步初始化在外部完成 - def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg: - """解析消息内容为Seg对象""" + async def initialize(self): + """异步初始化方法""" + self.message_segment = await self._parse_message(self.raw_message) + + async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg: + """异步解析消息内容为Seg对象""" cq_code_dict_list = [] segments = [] @@ -98,9 +102,10 @@ class MessageRecvCQ(MessageCQ): # 转换CQ码为Seg对象 for code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) - if message_obj.translated_segments: - segments.append(message_obj.translated_segments) + cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) + await cq_code_obj.translate() # 异步调用translate + if cq_code_obj.translated_segments: + segments.append(cq_code_obj.translated_segments) # 如果只有一个segment,直接返回 if len(segments) == 1: @@ -133,9 +138,7 @@ class MessageSendCQ(MessageCQ): self.message_segment = message_segment self.raw_message = self._generate_raw_message() - def _generate_raw_message( - self, - ) -> str: + def _generate_raw_message(self) -> str: """将Seg对象转换为raw_message""" segments = [] diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index a41ed51e2..ec0dac3d0 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -14,24 +14,24 @@ from .chat_stream import chat_manager class PromptBuilder: def __init__(self): - self.prompt_built = '' - self.activate_messages = '' + self.prompt_built = "" + self.activate_messages = "" - - - async def _build_prompt(self, - message_txt: str, - sender_name: str = "某人", - relationship_value: float = 0.0, - stream_id: Optional[int] = None) -> tuple[str, str]: + async def _build_prompt( + self, + message_txt: str, + sender_name: str = "某人", + relationship_value: float = 0.0, + stream_id: Optional[int] = None, + ) -> tuple[str, str]: """构建prompt - + Args: message_txt: 消息文本 sender_name: 发送者昵称 relationship_value: 关系值 group_id: 群组ID - + Returns: str: 构建好的prompt """ @@ -56,46 +56,43 @@ class PromptBuilder: current_date = time.strftime("%Y-%m-%d", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime()) bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() - prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' + prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n""" # 知识构建 start_time = time.time() - prompt_info = '' - promt_info_prompt = '' + prompt_info = "" + promt_info_prompt = "" prompt_info = await self.get_prompt_info(message_txt, threshold=0.5) if prompt_info: - prompt_info = f'''你有以下这些[知识]:{prompt_info}请你记住上面的[ - 知识],之后可能会用到-''' + prompt_info = f"""你有以下这些[知识]:{prompt_info}请你记住上面的[ + 知识],之后可能会用到-""" end_time = time.time() logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒") # 获取聊天上下文 - chat_in_group=True - chat_talking_prompt = '' + chat_in_group = True + chat_talking_prompt = "" if stream_id: - chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) - chat_stream=chat_manager.get_stream(stream_id) + chat_talking_prompt = get_recent_group_detailed_plain_text( + stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True + ) + chat_stream = chat_manager.get_stream(stream_id) if chat_stream.group_info: chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" else: - chat_in_group=False + chat_in_group = False chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}" # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") - - - + # 使用新的记忆获取方法 - memory_prompt = '' + memory_prompt = "" start_time = time.time() # 调用 hippocampus 的 get_relevant_memories 方法 relevant_memories = await hippocampus.get_relevant_memories( - text=message_txt, - max_topics=5, - similarity_threshold=0.4, - max_memory_num=5 + text=message_txt, max_topics=5, similarity_threshold=0.4, max_memory_num=5 ) if relevant_memories: @@ -115,56 +112,58 @@ class PromptBuilder: logger.info(f"回忆耗时: {(end_time - start_time):.3f}秒") # 激活prompt构建 - activate_prompt = '' + activate_prompt = "" if chat_in_group: - activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。" + activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。" else: activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。" # 关键词检测与反应 - keywords_reaction_prompt = '' + keywords_reaction_prompt = "" for rule in global_config.keywords_reaction_rules: if rule.get("enable", False): if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): - logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}") - keywords_reaction_prompt += rule.get("reaction", "") + ',' - - #人格选择 - personality=global_config.PROMPT_PERSONALITY + logger.info( + f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}" + ) + keywords_reaction_prompt += rule.get("reaction", "") + "," + + # 人格选择 + personality = global_config.PROMPT_PERSONALITY probability_1 = global_config.PERSONALITY_1 probability_2 = global_config.PERSONALITY_2 probability_3 = global_config.PERSONALITY_3 - - prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)},' + + prompt_personality = f"{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{'/'.join(global_config.BOT_ALIAS_NAMES)}," personality_choice = random.random() if chat_in_group: - prompt_in_group=f"你正在浏览{chat_stream.platform}群" + prompt_in_group = f"你正在浏览{chat_stream.platform}群" else: - prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊" + prompt_in_group = f"你正在{chat_stream.platform}上和{sender_name}私聊" if personality_choice < probability_1: # 第一种人格 - prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, + prompt_personality += f"""{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt} - 请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。''' + 请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。""" elif personality_choice < probability_1 + probability_2: # 第二种人格 - prompt_personality += f'''{personality[1]}, 你正在浏览qq群,{promt_info_prompt}, + prompt_personality += f"""{personality[1]}, 你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt} - 请你表达自己的见解和观点。可以有个性。''' + 请你表达自己的见解和观点。可以有个性。""" else: # 第三种人格 - prompt_personality += f'''{personality[2]}, 你正在浏览qq群,{promt_info_prompt}, + prompt_personality += f"""{personality[2]}, 你正在浏览qq群,{promt_info_prompt}, 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt} - 请你表达自己的见解和观点。可以有个性。''' + 请你表达自己的见解和观点。可以有个性。""" # 中文高手(新加的好玩功能) - prompt_ger = '' + prompt_ger = "" if random.random() < 0.04: - prompt_ger += '你喜欢用倒装句' + prompt_ger += "你喜欢用倒装句" if random.random() < 0.02: - prompt_ger += '你喜欢用反问句' + prompt_ger += "你喜欢用反问句" if random.random() < 0.01: - prompt_ger += '你喜欢用文言文' + prompt_ger += "你喜欢用文言文" # 额外信息要求 - extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' + extra_info = """但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容""" # 合并prompt prompt = "" @@ -175,16 +174,16 @@ class PromptBuilder: prompt += f"{prompt_ger}\n" prompt += f"{extra_info}\n" - '''读空气prompt处理''' + """读空气prompt处理""" activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" - prompt_personality_check = '' + prompt_personality_check = "" extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复,如果自己正在和别人聊天一定要回复,其他话题如果合适搭话也可以回复,如果认为应该回复请输出yes,否则输出no,请注意是决定是否需要回复,而不是编写回复内容,除了yes和no不要输出任何回复内容。" if personality_choice < probability_1: # 第一种人格 - prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}""" elif personality_choice < probability_1 + probability_2: # 第二种人格 - prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}""" else: # 第三种人格 - prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME},{personality[2]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' + prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME},{personality[2]}, 你正在浏览qq群,{promt_info_prompt} {activate_prompt_check} {extra_check_info}""" prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" @@ -194,38 +193,38 @@ class PromptBuilder: current_date = time.strftime("%Y-%m-%d", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime()) bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() - prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' + prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n""" - chat_talking_prompt = '' + chat_talking_prompt = "" if group_id: - chat_talking_prompt = get_recent_group_detailed_plain_text(group_id, - limit=global_config.MAX_CONTEXT_SIZE, - combine=True) + chat_talking_prompt = get_recent_group_detailed_plain_text( + group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True + ) chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # 获取主动发言的话题 all_nodes = memory_graph.dots - all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes) + all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes) nodes_for_select = random.sample(all_nodes, 5) topics = [info[0] for info in nodes_for_select] infos = [info[1] for info in nodes_for_select] # 激活prompt构建 - activate_prompt = '' + activate_prompt = "" activate_prompt = "以上是群里正在进行的聊天。" personality = global_config.PROMPT_PERSONALITY - prompt_personality = '' + prompt_personality = "" personality_choice = random.random() if personality_choice < probability_1: # 第一种人格 - prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}''' + prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}""" elif personality_choice < probability_1 + probability_2: # 第二种人格 - prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}''' + prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}""" else: # 第三种人格 - prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}''' + prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}""" - topics_str = ','.join(f"\"{topics}\"") + topics_str = ",".join(f'"{topics}"') prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)" prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}" @@ -234,8 +233,8 @@ class PromptBuilder: return prompt_initiative_select, nodes_for_select, prompt_regular def _build_initiative_prompt_check(self, selected_node, prompt_regular): - memory = random.sample(selected_node['memory_items'], 3) - memory = '\n'.join(memory) + memory = random.sample(selected_node["memory_items"], 3) + memory = "\n".join(memory) prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,综合群内的氛围,如果认为应该发言请输出yes,否则输出no,请注意是决定是否需要发言,而不是编写回复内容,除了yes和no不要输出任何回复内容。" return prompt_for_check, memory @@ -244,7 +243,7 @@ class PromptBuilder: return prompt_for_initiative async def get_prompt_info(self, message: str, threshold: float): - related_info = '' + related_info = "" logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") embedding = await get_embedding(message) related_info += self.get_info_from_db(embedding, threshold=threshold) @@ -253,7 +252,7 @@ class PromptBuilder: def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: if not query_embedding: - return '' + return "" # 使用余弦相似度计算 pipeline = [ { @@ -265,12 +264,14 @@ class PromptBuilder: "in": { "$add": [ "$$value", - {"$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]} - ]} + { + "$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]}, + ] + }, ] - } + }, } }, "magnitude1": { @@ -278,7 +279,7 @@ class PromptBuilder: "$reduce": { "input": "$embedding", "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, } } }, @@ -287,19 +288,13 @@ class PromptBuilder: "$reduce": { "input": query_embedding, "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, } } - } - } - }, - { - "$addFields": { - "similarity": { - "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}] - } + }, } }, + {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, { "$match": { "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 @@ -307,17 +302,17 @@ class PromptBuilder: }, {"$sort": {"similarity": -1}}, {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1}} + {"$project": {"content": 1, "similarity": 1}}, ] results = list(db.knowledges.aggregate(pipeline)) # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") if not results: - return '' + return "" # 返回所有找到的内容,用换行分隔 - return '\n'.join(str(result['content']) for result in results) + return "\n".join(str(result["content"]) for result in results) prompt_builder = PromptBuilder() diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index dd6d7d4d1..6d900ba54 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -34,7 +34,7 @@ class ImageManager: self._ensure_description_collection() self._ensure_image_dir() self._initialized = True - self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) + self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000) def _ensure_image_dir(self): """确保图像存储目录存在"""