fix: 增大了默认的maxtoken防止溢出,messagecq改异步get_image防止阻塞

This commit is contained in:
tcmofashi
2025-03-14 15:38:33 +08:00
parent e2c5d42634
commit d3fe02e467
7 changed files with 207 additions and 286 deletions

View File

@@ -74,6 +74,7 @@ class ChatBot:
reply_message=None, reply_message=None,
platform="qq", platform="qq",
) )
await message_cq.initialize()
message_json = message_cq.to_dict() message_json = message_cq.to_dict()
# 进入maimbot # 进入maimbot
@@ -121,7 +122,12 @@ class ChatBot:
if event.user_id in global_config.ban_user_id: if event.user_id in global_config.ban_user_id:
return return
if event.reply and hasattr(event.reply, 'sender') and hasattr(event.reply.sender, 'user_id') and event.reply.sender.user_id in global_config.ban_user_id: if (
event.reply
and hasattr(event.reply, "sender")
and hasattr(event.reply.sender, "user_id")
and event.reply.sender.user_id in global_config.ban_user_id
):
logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息") logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息")
return return
# 处理私聊消息 # 处理私聊消息
@@ -171,6 +177,7 @@ class ChatBot:
reply_message=event.reply, reply_message=event.reply,
platform="qq", platform="qq",
) )
await message_cq.initialize()
message_json = message_cq.to_dict() message_json = message_cq.to_dict()
# 进入maimbot # 进入maimbot

View File

@@ -1,48 +1,28 @@
import base64 import base64
import html import html
import time import time
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import ssl
import os import os
import aiohttp
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .message_base import Seg from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname from .utils_user import get_user_nickname, get_groupname
from .message_base import GroupInfo, UserInfo from .message_base import GroupInfo, UserInfo
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616 # 创建SSL上下文
ctx = create_urllib3_context() ssl_context = ssl.create_default_context()
ctx.load_default_certs() ssl_context.set_ciphers("AES128-GCM-SHA256")
ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass @dataclass
@@ -70,14 +50,12 @@ class CQCode:
"""初始化LLM实例""" """初始化LLM实例"""
pass pass
def translate(self): async def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text": if self.type == "text":
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
type="text", data=self.params.get("text", "")
)
elif self.type == "image": elif self.type == "image":
base64_data = self.translate_image() base64_data = await self.translate_image()
if base64_data: if base64_data:
if self.params.get("sub_type") == "0": if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data) self.translated_segments = Seg(type="image", data=base64_data)
@@ -90,22 +68,18 @@ class CQCode:
self.translated_segments = Seg(type="text", data="@[全体成员]") self.translated_segments = Seg(type="text", data="@[全体成员]")
else: else:
user_nickname = get_user_nickname(self.params.get("qq", "")) user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
type="text", data=f"[@{user_nickname or '某人'}]"
)
elif self.type == "reply": elif self.type == "reply":
reply_segments = self.translate_reply() reply_segments = await self.translate_reply()
if reply_segments: if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments) self.translated_segments = Seg(type="seglist", data=reply_segments)
else: else:
self.translated_segments = Seg(type="text", data="[回复某人消息]") self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face": elif self.type == "face":
face_id = self.params.get("id", "") face_id = self.params.get("id", "")
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == "forward": elif self.type == "forward":
forward_segments = self.translate_forward() forward_segments = await self.translate_forward()
if forward_segments: if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments) self.translated_segments = Seg(type="seglist", data=forward_segments)
else: else:
@@ -113,18 +87,8 @@ class CQCode:
else: else:
self.translated_segments = Seg(type="text", data=f"[{self.type}]") self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self): async def get_img(self) -> Optional[str]:
""" """异步获取图片并转换为base64"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
'Accept-Encoding': 'gzip, deflate, br',
'Connection': 'keep-alive',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
"""
# 腾讯专用请求头配置
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36", "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*", "Accept": "text/html, application/xhtml xml, */*",
@@ -133,61 +97,63 @@ class CQCode:
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
} }
url = html.unescape(self.params["url"]) url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")): if not url.startswith(("http://", "https://")):
return None return None
# 创建专用会话
session = requests.session()
session.adapters.pop("https://", None)
session.mount("https://", TencentSSLAdapter(ctx))
max_retries = 3 max_retries = 3
for retry in range(max_retries): for retry in range(max_retries):
try: try:
response = session.get( logger.debug(f"获取图片中: {url}")
# 设置SSL上下文和创建连接器
conn = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=conn) as session:
async with session.get(
url, url,
headers=headers, headers=headers,
timeout=15, timeout=aiohttp.ClientTimeout(total=15),
allow_redirects=True, allow_redirects=True,
stream=True, # 流式传输避免大内存问题 ) as response:
)
# 腾讯服务器特殊状态码处理 # 腾讯服务器特殊状态码处理
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url: if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
return None return None
if response.status_code != 200: if response.status != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") raise aiohttp.ClientError(f"HTTP {response.status}")
# 验证内容类型 # 验证内容类型
content_type = response.headers.get("Content-Type", "") content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"): if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}") raise ValueError(f"非图片内容类型: {content_type}")
# 读取响应内容
content = await response.read()
logger.debug(f"获取图片成功: {url}")
# 转换为Base64 # 转换为Base64
image_base64 = base64.b64encode(response.content).decode("utf-8") image_base64 = base64.b64encode(content).decode("utf-8")
self.image_base64 = image_base64 self.image_base64 = image_base64
return image_base64 return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: except (aiohttp.ClientError, ValueError) as e:
if retry == max_retries - 1: if retry == max_retries - 1:
logger.error(f"最终请求失败: {str(e)}") logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5**retry) # 指数退避 await asyncio.sleep(1.5**retry) # 指数退避
except Exception: except Exception as e:
logger.exception("[未知错误]") logger.exception(f"获取图片时发生未知错误: {str(e)}")
return None return None
return None return None
def translate_image(self) -> Optional[str]: async def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串""" """处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params: if "url" not in self.params:
return None return None
return self.get_img() return await self.get_img()
def translate_forward(self) -> Optional[List[Seg]]: async def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表""" """处理转发消息返回Seg列表"""
try: try:
if "content" not in self.params: if "content" not in self.params:
@@ -217,15 +183,16 @@ class CQCode:
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq', user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0), user_id=msg.get("user_id", 0),
user_nickname=nickname, user_nickname=nickname,
) )
group_info=GroupInfo( group_info = GroupInfo(
platform='qq', platform="qq",
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)) group_name=get_groupname(msg.get("group_id", 0)),
) )
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
@@ -235,24 +202,23 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_info=group_info, group_info=group_info,
) )
content_seg = Seg( await message_obj.initialize()
type="seglist", data=[message_obj.message_segment] content_seg = Seg(type="seglist", data=[message_obj.message_segment])
)
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
user_info=UserInfo( user_info = UserInfo(
platform='qq', platform="qq",
user_id=msg.get("user_id", 0), user_id=msg.get("user_id", 0),
user_nickname=nickname, user_nickname=nickname,
) )
group_info=GroupInfo( group_info = GroupInfo(
platform='qq', platform="qq",
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)) group_name=get_groupname(msg.get("group_id", 0)),
) )
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0), message_id=msg.get("message_id", 0),
@@ -261,9 +227,8 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_info=group_info, group_info=group_info,
) )
content_seg = Seg( await message_obj.initialize()
type="seglist", data=[message_obj.message_segment] content_seg = Seg(type="seglist", data=[message_obj.message_segment])
)
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
@@ -277,7 +242,7 @@ class CQCode:
logger.error(f"处理转发消息失败: {str(e)}") logger.error(f"处理转发消息失败: {str(e)}")
return None return None
def translate_reply(self) -> Optional[List[Seg]]: async def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表""" """处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
@@ -285,22 +250,19 @@ class CQCode:
return None return None
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname), user_info=UserInfo(
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
),
message_id=self.reply_message.message_id, message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_info=GroupInfo(group_id=self.reply_message.group_id), group_info=GroupInfo(group_id=self.reply_message.group_id),
) )
await message_obj.initialize()
segments = [] segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append( segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else: else:
segments.append( segments.append(
Seg( Seg(
@@ -318,16 +280,12 @@ class CQCode:
@staticmethod @staticmethod
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return ( return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&")
text.replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace("&", "&")
)
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode: def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
@@ -353,11 +311,9 @@ class CQCode_tool:
params=params, params=params,
group_info=msg.message_info.group_info, group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info, user_info=msg.message_info.user_info,
reply_message=reply reply_message=reply,
) )
# 进行翻译处理
instance.translate()
return instance return instance
@staticmethod @staticmethod
@@ -383,12 +339,7 @@ class CQCode_tool:
# 确保使用绝对路径 # 确保使用绝对路径
abs_path = os.path.abspath(file_path) abs_path = os.path.abspath(file_path)
# 转义特殊字符 # 转义特殊字符
escaped_path = ( escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
abs_path.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@@ -403,10 +354,7 @@ class CQCode_tool:
""" """
# 转义base64数据 # 转义base64数据
escaped_base64 = ( escaped_base64 = (
base64_data.replace("&", "&") base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
) )
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@@ -422,10 +370,7 @@ class CQCode_tool:
""" """
# 转义base64数据 # 转义base64数据
escaped_base64 = ( escaped_base64 = (
base64_data.replace("&", "&") base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
) )
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"

View File

@@ -37,7 +37,7 @@ class EmojiManager:
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request( self.llm_emotion_judge = LLM_request(
model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8 model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8
) # 更高的温度更少的token后续可以根据情绪来调整温度 ) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
@@ -275,9 +275,6 @@ class EmojiManager:
continue continue
logger.info(f"check通过 {check}") logger.info(f"check通过 {check}")
if description is not None:
embedding = await get_embedding(description)
if description is not None: if description is not None:
embedding = await get_embedding(description) embedding = await get_embedding(description)

View File

@@ -25,30 +25,19 @@ class ResponseGenerator:
max_tokens=1000, max_tokens=1000,
stream=True, stream=True,
) )
self.model_v3 = LLM_request( self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7, max_tokens=3000)
model=global_config.llm_normal, temperature=0.7, max_tokens=1000 self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=3000)
) self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7, max_tokens=3000)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
async def generate_response( async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
rand = random.random() rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY: if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = "r1" self.current_model_type = "r1"
current_model = self.model_r1 current_model = self.model_r1
elif ( elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3" self.current_model_type = "v3"
current_model = self.model_v3 current_model = self.model_v3
else: else:
@@ -57,37 +46,28 @@ class ResponseGenerator:
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model( model_response = await self._generate_response_with_model(message, current_model)
message, current_model
)
raw_content = model_response raw_content = model_response
# print(f"raw_content: {raw_content}") # print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}") # print(f"model_response: {model_response}")
if model_response: if model_response:
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
return model_response, raw_content return model_response, raw_content
return None, raw_content return None, raw_content
async def _generate_response_with_model( async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request) -> Optional[str]:
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复""" """使用指定的模型生成回复"""
sender_name = ( sender_name = message.chat_stream.user_info.user_nickname or f"用户{message.chat_stream.user_info.user_id}"
message.chat_stream.user_info.user_nickname
or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname: if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}" sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值 # 获取关系值
relationship_value = ( relationship_value = (
relationship_manager.get_relationship( relationship_manager.get_relationship(message.chat_stream).relationship_value
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream) if relationship_manager.get_relationship(message.chat_stream)
else 0.0 else 0.0
) )
@@ -212,13 +192,11 @@ class InitiativeMessageGenerate:
def __init__(self): def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7) self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request( self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
model=global_config.llm_reasoning_minor, temperature=0.7
)
def gen_response(self, message: Message): def gen_response(self, message: Message):
topic_select_prompt, dots_for_select, prompt_template = ( topic_select_prompt, dots_for_select, prompt_template = prompt_builder._build_initiative_prompt_select(
prompt_builder._build_initiative_prompt_select(message.group_id) message.group_id
) )
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt) content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
logger.debug(f"{content_select} {reasoning}") logger.debug(f"{content_select} {reasoning}")
@@ -230,16 +208,12 @@ class InitiativeMessageGenerate:
return None return None
else: else:
return None return None
prompt_check, memory = prompt_builder._build_initiative_prompt_check( prompt_check, memory = prompt_builder._build_initiative_prompt_check(select_dot[1], prompt_template)
select_dot[1], prompt_template
)
content_check, reasoning_check = self.model_v3.generate_response(prompt_check) content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
logger.info(f"{content_check} {reasoning_check}") logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower(): if "yes" not in content_check.lower():
return None return None
prompt = prompt_builder._build_initiative_prompt( prompt = prompt_builder._build_initiative_prompt(select_dot, prompt_template, memory)
select_dot, prompt_template, memory
)
content, reasoning = self.model_r1.generate_response_async(prompt) content, reasoning = self.model_r1.generate_response_async(prompt)
logger.debug(f"[DEBUG] {content} {reasoning}") logger.debug(f"[DEBUG] {content} {reasoning}")
return content return content

View File

@@ -57,16 +57,20 @@ class MessageRecvCQ(MessageCQ):
# 私聊消息不携带group_info # 私聊消息不携带group_info
if group_info is None: if group_info is None:
pass pass
elif group_info.group_name is None: elif group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id) group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段 # 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message) self.message_segment = None # 初始化为None
self.raw_message = raw_message self.raw_message = raw_message
# 异步初始化在外部完成
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg: async def initialize(self):
"""解析消息内容为Seg对象""" """异步初始化方法"""
self.message_segment = await self._parse_message(self.raw_message)
async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""异步解析消息内容为Seg对象"""
cq_code_dict_list = [] cq_code_dict_list = []
segments = [] segments = []
@@ -98,9 +102,10 @@ class MessageRecvCQ(MessageCQ):
# 转换CQ码为Seg对象 # 转换CQ码为Seg对象
for code_item in cq_code_dict_list: for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
if message_obj.translated_segments: await cq_code_obj.translate() # 异步调用translate
segments.append(message_obj.translated_segments) if cq_code_obj.translated_segments:
segments.append(cq_code_obj.translated_segments)
# 如果只有一个segment直接返回 # 如果只有一个segment直接返回
if len(segments) == 1: if len(segments) == 1:
@@ -133,9 +138,7 @@ class MessageSendCQ(MessageCQ):
self.message_segment = message_segment self.message_segment = message_segment
self.raw_message = self._generate_raw_message() self.raw_message = self._generate_raw_message()
def _generate_raw_message( def _generate_raw_message(self) -> str:
self,
) -> str:
"""将Seg对象转换为raw_message""" """将Seg对象转换为raw_message"""
segments = [] segments = []

View File

@@ -14,16 +14,16 @@ from .chat_stream import chat_manager
class PromptBuilder: class PromptBuilder:
def __init__(self): def __init__(self):
self.prompt_built = '' self.prompt_built = ""
self.activate_messages = '' self.activate_messages = ""
async def _build_prompt(
self,
async def _build_prompt(self,
message_txt: str, message_txt: str,
sender_name: str = "某人", sender_name: str = "某人",
relationship_value: float = 0.0, relationship_value: float = 0.0,
stream_id: Optional[int] = None) -> tuple[str, str]: stream_id: Optional[int] = None,
) -> tuple[str, str]:
"""构建prompt """构建prompt
Args: Args:
@@ -56,46 +56,43 @@ class PromptBuilder:
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
# 知识构建 # 知识构建
start_time = time.time() start_time = time.time()
prompt_info = '' prompt_info = ""
promt_info_prompt = '' promt_info_prompt = ""
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5) prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info: if prompt_info:
prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[ prompt_info = f"""你有以下这些[知识]{prompt_info}请你记住上面的[
知识],之后可能会用到-''' 知识],之后可能会用到-"""
end_time = time.time() end_time = time.time()
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_in_group=True chat_in_group = True
chat_talking_prompt = '' chat_talking_prompt = ""
if stream_id: if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(
chat_stream=chat_manager.get_stream(stream_id) stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
)
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream.group_info: if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else: else:
chat_in_group=False chat_in_group = False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 使用新的记忆获取方法 # 使用新的记忆获取方法
memory_prompt = '' memory_prompt = ""
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt, text=message_txt, max_topics=5, similarity_threshold=0.4, max_memory_num=5
max_topics=5,
similarity_threshold=0.4,
max_memory_num=5
) )
if relevant_memories: if relevant_memories:
@@ -115,56 +112,58 @@ class PromptBuilder:
logger.info(f"回忆耗时: {(end_time - start_time):.3f}") logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ""
if chat_in_group: if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else: else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 关键词检测与反应 # 关键词检测与反应
keywords_reaction_prompt = '' keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules: for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False): if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}") logger.info(
keywords_reaction_prompt += rule.get("reaction", "") + '' f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
#人格选择 # 人格选择
personality=global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1 probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2 probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3 probability_3 = global_config.PERSONALITY_3
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}' prompt_personality = f"{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{'/'.join(global_config.BOT_ALIAS_NAMES)}"
personality_choice = random.random() personality_choice = random.random()
if chat_in_group: if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}" prompt_in_group = f"你正在浏览{chat_stream.platform}"
else: else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊" prompt_in_group = f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, prompt_personality += f"""{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt} 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。''' 请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。"""
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality += f'''{personality[1]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality += f"""{personality[1]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt} 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。"""
else: # 第三种人格 else: # 第三种人格
prompt_personality += f'''{personality[2]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality += f"""{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt} 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。"""
# 中文高手(新加的好玩功能) # 中文高手(新加的好玩功能)
prompt_ger = '' prompt_ger = ""
if random.random() < 0.04: if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句' prompt_ger += "你喜欢用倒装句"
if random.random() < 0.02: if random.random() < 0.02:
prompt_ger += '你喜欢用反问句' prompt_ger += "你喜欢用反问句"
if random.random() < 0.01: if random.random() < 0.01:
prompt_ger += '你喜欢用文言文' prompt_ger += "你喜欢用文言文"
# 额外信息要求 # 额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' extra_info = """但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容"""
# 合并prompt # 合并prompt
prompt = "" prompt = ""
@@ -175,16 +174,16 @@ class PromptBuilder:
prompt += f"{prompt_ger}\n" prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n" prompt += f"{extra_info}\n"
'''读空气prompt处理''' """读空气prompt处理"""
activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = '' prompt_personality_check = ""
extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。" extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}"""
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}"""
else: # 第三种人格 else: # 第三种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f"""你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}"""
prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
@@ -194,38 +193,38 @@ class PromptBuilder:
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
chat_talking_prompt = '' chat_talking_prompt = ""
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id, chat_talking_prompt = get_recent_group_detailed_plain_text(
limit=global_config.MAX_CONTEXT_SIZE, group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
combine=True) )
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题 # 获取主动发言的话题
all_nodes = memory_graph.dots all_nodes = memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes) all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5) nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select] topics = [info[0] for info in nodes_for_select]
infos = [info[1] for info in nodes_for_select] infos = [info[1] for info in nodes_for_select]
# 激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ""
activate_prompt = "以上是群里正在进行的聊天。" activate_prompt = "以上是群里正在进行的聊天。"
personality = global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
prompt_personality = '' prompt_personality = ""
personality_choice = random.random() personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}"""
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}"""
else: # 第三种人格 else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}"""
topics_str = ','.join(f"\"{topics}\"") topics_str = ",".join(f'"{topics}"')
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)" prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}" prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
@@ -234,8 +233,8 @@ class PromptBuilder:
return prompt_initiative_select, nodes_for_select, prompt_regular return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt_check(self, selected_node, prompt_regular): def _build_initiative_prompt_check(self, selected_node, prompt_regular):
memory = random.sample(selected_node['memory_items'], 3) memory = random.sample(selected_node["memory_items"], 3)
memory = '\n'.join(memory) memory = "\n".join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。" prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory return prompt_for_check, memory
@@ -244,7 +243,7 @@ class PromptBuilder:
return prompt_for_initiative return prompt_for_initiative
async def get_prompt_info(self, message: str, threshold: float): async def get_prompt_info(self, message: str, threshold: float):
related_info = '' related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message) embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding, threshold=threshold) related_info += self.get_info_from_db(embedding, threshold=threshold)
@@ -253,7 +252,7 @@ class PromptBuilder:
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
if not query_embedding: if not query_embedding:
return '' return ""
# 使用余弦相似度计算 # 使用余弦相似度计算
pipeline = [ pipeline = [
{ {
@@ -265,12 +264,14 @@ class PromptBuilder:
"in": { "in": {
"$add": [ "$add": [
"$$value", "$$value",
{"$multiply": [ {
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]}, {"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]} {"$arrayElemAt": [query_embedding, "$$this"]},
]}
] ]
} },
]
},
} }
}, },
"magnitude1": { "magnitude1": {
@@ -278,7 +279,7 @@ class PromptBuilder:
"$reduce": { "$reduce": {
"input": "$embedding", "input": "$embedding",
"initialValue": 0, "initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
} }
} }
}, },
@@ -287,19 +288,13 @@ class PromptBuilder:
"$reduce": { "$reduce": {
"input": query_embedding, "input": query_embedding,
"initialValue": 0, "initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
} }
} }
}, },
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
} }
}, },
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{ {
"$match": { "$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
@@ -307,17 +302,17 @@ class PromptBuilder:
}, },
{"$sort": {"similarity": -1}}, {"$sort": {"similarity": -1}},
{"$limit": limit}, {"$limit": limit},
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}},
] ]
results = list(db.knowledges.aggregate(pipeline)) results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:
return '' return ""
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results) return "\n".join(str(result["content"]) for result in results)
prompt_builder = PromptBuilder() prompt_builder = PromptBuilder()

View File

@@ -34,7 +34,7 @@ class ImageManager:
self._ensure_description_collection() self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000)
def _ensure_image_dir(self): def _ensure_image_dir(self):
"""确保图像存储目录存在""" """确保图像存储目录存在"""