Merge branch 'main-fix' into relationship
This commit is contained in:
11
bot.py
11
bot.py
@@ -10,9 +10,8 @@ import uvicorn
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from nonebot.adapters.onebot.v11 import Adapter
|
from nonebot.adapters.onebot.v11 import Adapter
|
||||||
import platform
|
import platform
|
||||||
from src.plugins.utils.logger_config import setup_logger
|
from src.plugins.utils.logger_config import LogModule, LogClassification
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
# 配置日志格式
|
# 配置日志格式
|
||||||
|
|
||||||
@@ -102,7 +101,9 @@ def load_env():
|
|||||||
|
|
||||||
|
|
||||||
def load_logger():
|
def load_logger():
|
||||||
setup_logger()
|
global logger # 使得bot.py中其他函数也能调用
|
||||||
|
log_module = LogModule()
|
||||||
|
logger = log_module.setup_logger(LogClassification.BASE)
|
||||||
|
|
||||||
|
|
||||||
def scan_provider(env_config: dict):
|
def scan_provider(env_config: dict):
|
||||||
@@ -174,8 +175,6 @@ def raw_main():
|
|||||||
if platform.system().lower() != "windows":
|
if platform.system().lower() != "windows":
|
||||||
time.tzset()
|
time.tzset()
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
load_logger()
|
|
||||||
easter_egg()
|
easter_egg()
|
||||||
init_config()
|
init_config()
|
||||||
init_env()
|
init_env()
|
||||||
@@ -207,6 +206,8 @@ def raw_main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
|
# 配置日志,使得主程序直接退出时候也能访问logger
|
||||||
|
load_logger()
|
||||||
raw_main()
|
raw_main()
|
||||||
|
|
||||||
app = nonebot.get_asgi()
|
app = nonebot.get_asgi()
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from nonebot.adapters.onebot.v11 import (
|
|||||||
PokeNotifyEvent,
|
PokeNotifyEvent,
|
||||||
GroupRecallNoticeEvent,
|
GroupRecallNoticeEvent,
|
||||||
FriendRecallNoticeEvent,
|
FriendRecallNoticeEvent,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..memory_system.memory import hippocampus
|
from ..memory_system.memory import hippocampus
|
||||||
@@ -32,10 +31,11 @@ from .utils_image import image_path_to_base64
|
|||||||
from .utils_user import get_user_nickname, get_user_cardname, get_groupname
|
from .utils_user import get_user_nickname, get_user_cardname, get_groupname
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .message_base import UserInfo, GroupInfo, Seg
|
from .message_base import UserInfo, GroupInfo, Seg
|
||||||
from ..utils.logger_config import setup_logger, LogModule
|
from ..utils.logger_config import LogClassification, LogModule
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = setup_logger(LogModule.CHAT)
|
log_module = LogModule()
|
||||||
|
logger = log_module.setup_logger(LogClassification.CHAT)
|
||||||
|
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
@@ -95,8 +95,6 @@ class ChatBot:
|
|||||||
|
|
||||||
await self.storage.store_recalled_message(event.message_id, time.time(), chat)
|
await self.storage.store_recalled_message(event.message_id, time.time(), chat)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
|
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
|
||||||
"""处理收到的消息"""
|
"""处理收到的消息"""
|
||||||
|
|
||||||
@@ -161,6 +159,7 @@ class ChatBot:
|
|||||||
reply_message=event.reply,
|
reply_message=event.reply,
|
||||||
platform="qq",
|
platform="qq",
|
||||||
)
|
)
|
||||||
|
await message_cq.initialize()
|
||||||
message_json = message_cq.to_dict()
|
message_json = message_cq.to_dict()
|
||||||
|
|
||||||
# 进入maimbot
|
# 进入maimbot
|
||||||
@@ -387,6 +386,7 @@ class ChatBot:
|
|||||||
reply_message=None,
|
reply_message=None,
|
||||||
platform="qq",
|
platform="qq",
|
||||||
)
|
)
|
||||||
|
await message_cq.initialize()
|
||||||
message_json = message_cq.to_dict()
|
message_json = message_cq.to_dict()
|
||||||
|
|
||||||
message = MessageRecv(message_json)
|
message = MessageRecv(message_json)
|
||||||
|
|||||||
@@ -1,48 +1,28 @@
|
|||||||
import base64
|
import base64
|
||||||
import html
|
import html
|
||||||
import time
|
import time
|
||||||
|
import asyncio
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
import ssl
|
||||||
import os
|
import os
|
||||||
|
import aiohttp
|
||||||
import requests
|
|
||||||
|
|
||||||
# 解析各种CQ码
|
|
||||||
# 包含CQ码类
|
|
||||||
import urllib3
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from urllib3.util import create_urllib3_context
|
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .mapper import emojimapper
|
from .mapper import emojimapper
|
||||||
from .message_base import Seg
|
from .message_base import Seg
|
||||||
from .utils_user import get_user_nickname,get_groupname
|
from .utils_user import get_user_nickname, get_groupname
|
||||||
from .message_base import GroupInfo, UserInfo
|
from .message_base import GroupInfo, UserInfo
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616
|
# 创建SSL上下文
|
||||||
ctx = create_urllib3_context()
|
ssl_context = ssl.create_default_context()
|
||||||
ctx.load_default_certs()
|
ssl_context.set_ciphers("AES128-GCM-SHA256")
|
||||||
ctx.set_ciphers("AES128-GCM-SHA256")
|
|
||||||
|
|
||||||
|
|
||||||
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
|
|
||||||
def __init__(self, ssl_context=None, **kwargs):
|
|
||||||
self.ssl_context = ssl_context
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def init_poolmanager(self, connections, maxsize, block=False):
|
|
||||||
self.poolmanager = urllib3.poolmanager.PoolManager(
|
|
||||||
num_pools=connections,
|
|
||||||
maxsize=maxsize,
|
|
||||||
block=block,
|
|
||||||
ssl_context=self.ssl_context,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -70,14 +50,12 @@ class CQCode:
|
|||||||
"""初始化LLM实例"""
|
"""初始化LLM实例"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def translate(self):
|
async def translate(self):
|
||||||
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
||||||
if self.type == "text":
|
if self.type == "text":
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
|
||||||
type="text", data=self.params.get("text", "")
|
|
||||||
)
|
|
||||||
elif self.type == "image":
|
elif self.type == "image":
|
||||||
base64_data = self.translate_image()
|
base64_data = await self.translate_image()
|
||||||
if base64_data:
|
if base64_data:
|
||||||
if self.params.get("sub_type") == "0":
|
if self.params.get("sub_type") == "0":
|
||||||
self.translated_segments = Seg(type="image", data=base64_data)
|
self.translated_segments = Seg(type="image", data=base64_data)
|
||||||
@@ -90,22 +68,18 @@ class CQCode:
|
|||||||
self.translated_segments = Seg(type="text", data="@[全体成员]")
|
self.translated_segments = Seg(type="text", data="@[全体成员]")
|
||||||
else:
|
else:
|
||||||
user_nickname = get_user_nickname(self.params.get("qq", ""))
|
user_nickname = get_user_nickname(self.params.get("qq", ""))
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
|
||||||
type="text", data=f"[@{user_nickname or '某人'}]"
|
|
||||||
)
|
|
||||||
elif self.type == "reply":
|
elif self.type == "reply":
|
||||||
reply_segments = self.translate_reply()
|
reply_segments = await self.translate_reply()
|
||||||
if reply_segments:
|
if reply_segments:
|
||||||
self.translated_segments = Seg(type="seglist", data=reply_segments)
|
self.translated_segments = Seg(type="seglist", data=reply_segments)
|
||||||
else:
|
else:
|
||||||
self.translated_segments = Seg(type="text", data="[回复某人消息]")
|
self.translated_segments = Seg(type="text", data="[回复某人消息]")
|
||||||
elif self.type == "face":
|
elif self.type == "face":
|
||||||
face_id = self.params.get("id", "")
|
face_id = self.params.get("id", "")
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
|
||||||
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
|
|
||||||
)
|
|
||||||
elif self.type == "forward":
|
elif self.type == "forward":
|
||||||
forward_segments = self.translate_forward()
|
forward_segments = await self.translate_forward()
|
||||||
if forward_segments:
|
if forward_segments:
|
||||||
self.translated_segments = Seg(type="seglist", data=forward_segments)
|
self.translated_segments = Seg(type="seglist", data=forward_segments)
|
||||||
else:
|
else:
|
||||||
@@ -113,18 +87,8 @@ class CQCode:
|
|||||||
else:
|
else:
|
||||||
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
|
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
|
||||||
|
|
||||||
def get_img(self):
|
async def get_img(self) -> Optional[str]:
|
||||||
"""
|
"""异步获取图片并转换为base64"""
|
||||||
headers = {
|
|
||||||
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
|
|
||||||
'Accept': 'image/*;q=0.8',
|
|
||||||
'Accept-Encoding': 'gzip, deflate, br',
|
|
||||||
'Connection': 'keep-alive',
|
|
||||||
'Cache-Control': 'no-cache',
|
|
||||||
'Pragma': 'no-cache'
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# 腾讯专用请求头配置
|
|
||||||
headers = {
|
headers = {
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
|
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
|
||||||
"Accept": "text/html, application/xhtml xml, */*",
|
"Accept": "text/html, application/xhtml xml, */*",
|
||||||
@@ -133,61 +97,63 @@ class CQCode:
|
|||||||
"Content-Type": "application/x-www-form-urlencoded",
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
}
|
}
|
||||||
|
|
||||||
url = html.unescape(self.params["url"])
|
url = html.unescape(self.params["url"])
|
||||||
if not url.startswith(("http://", "https://")):
|
if not url.startswith(("http://", "https://")):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 创建专用会话
|
|
||||||
session = requests.session()
|
|
||||||
session.adapters.pop("https://", None)
|
|
||||||
session.mount("https://", TencentSSLAdapter(ctx))
|
|
||||||
|
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
for retry in range(max_retries):
|
for retry in range(max_retries):
|
||||||
try:
|
try:
|
||||||
response = session.get(
|
logger.debug(f"获取图片中: {url}")
|
||||||
url,
|
# 设置SSL上下文和创建连接器
|
||||||
headers=headers,
|
conn = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
timeout=15,
|
async with aiohttp.ClientSession(connector=conn) as session:
|
||||||
allow_redirects=True,
|
async with session.get(
|
||||||
stream=True, # 流式传输避免大内存问题
|
url,
|
||||||
)
|
headers=headers,
|
||||||
|
timeout=aiohttp.ClientTimeout(total=15),
|
||||||
|
allow_redirects=True,
|
||||||
|
) as response:
|
||||||
|
# 腾讯服务器特殊状态码处理
|
||||||
|
if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
|
||||||
|
return None
|
||||||
|
|
||||||
# 腾讯服务器特殊状态码处理
|
if response.status != 200:
|
||||||
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
|
raise aiohttp.ClientError(f"HTTP {response.status}")
|
||||||
return None
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
# 验证内容类型
|
||||||
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
|
content_type = response.headers.get("Content-Type", "")
|
||||||
|
if not content_type.startswith("image/"):
|
||||||
|
raise ValueError(f"非图片内容类型: {content_type}")
|
||||||
|
|
||||||
# 验证内容类型
|
# 读取响应内容
|
||||||
content_type = response.headers.get("Content-Type", "")
|
content = await response.read()
|
||||||
if not content_type.startswith("image/"):
|
logger.debug(f"获取图片成功: {url}")
|
||||||
raise ValueError(f"非图片内容类型: {content_type}")
|
|
||||||
|
|
||||||
# 转换为Base64
|
# 转换为Base64
|
||||||
image_base64 = base64.b64encode(response.content).decode("utf-8")
|
image_base64 = base64.b64encode(content).decode("utf-8")
|
||||||
self.image_base64 = image_base64
|
self.image_base64 = image_base64
|
||||||
return image_base64
|
return image_base64
|
||||||
|
|
||||||
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
except (aiohttp.ClientError, ValueError) as e:
|
||||||
if retry == max_retries - 1:
|
if retry == max_retries - 1:
|
||||||
logger.error(f"最终请求失败: {str(e)}")
|
logger.error(f"最终请求失败: {str(e)}")
|
||||||
time.sleep(1.5**retry) # 指数退避
|
await asyncio.sleep(1.5**retry) # 指数退避
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("[未知错误]")
|
logger.exception(f"获取图片时发生未知错误: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def translate_image(self) -> Optional[str]:
|
async def translate_image(self) -> Optional[str]:
|
||||||
"""处理图片类型的CQ码,返回base64字符串"""
|
"""处理图片类型的CQ码,返回base64字符串"""
|
||||||
if "url" not in self.params:
|
if "url" not in self.params:
|
||||||
return None
|
return None
|
||||||
return self.get_img()
|
return await self.get_img()
|
||||||
|
|
||||||
def translate_forward(self) -> Optional[List[Seg]]:
|
async def translate_forward(self) -> Optional[List[Seg]]:
|
||||||
"""处理转发消息,返回Seg列表"""
|
"""处理转发消息,返回Seg列表"""
|
||||||
try:
|
try:
|
||||||
if "content" not in self.params:
|
if "content" not in self.params:
|
||||||
@@ -217,15 +183,16 @@ class CQCode:
|
|||||||
else:
|
else:
|
||||||
if raw_message:
|
if raw_message:
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
user_info=UserInfo(
|
|
||||||
platform='qq',
|
user_info = UserInfo(
|
||||||
|
platform="qq",
|
||||||
user_id=msg.get("user_id", 0),
|
user_id=msg.get("user_id", 0),
|
||||||
user_nickname=nickname,
|
user_nickname=nickname,
|
||||||
)
|
)
|
||||||
group_info=GroupInfo(
|
group_info = GroupInfo(
|
||||||
platform='qq',
|
platform="qq",
|
||||||
group_id=msg.get("group_id", 0),
|
group_id=msg.get("group_id", 0),
|
||||||
group_name=get_groupname(msg.get("group_id", 0))
|
group_name=get_groupname(msg.get("group_id", 0)),
|
||||||
)
|
)
|
||||||
|
|
||||||
message_obj = MessageRecvCQ(
|
message_obj = MessageRecvCQ(
|
||||||
@@ -235,24 +202,23 @@ class CQCode:
|
|||||||
plain_text=raw_message,
|
plain_text=raw_message,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
content_seg = Seg(
|
await message_obj.initialize()
|
||||||
type="seglist", data=[message_obj.message_segment]
|
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
content_seg = Seg(type="text", data="[空消息]")
|
content_seg = Seg(type="text", data="[空消息]")
|
||||||
else:
|
else:
|
||||||
if raw_message:
|
if raw_message:
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
|
|
||||||
user_info=UserInfo(
|
user_info = UserInfo(
|
||||||
platform='qq',
|
platform="qq",
|
||||||
user_id=msg.get("user_id", 0),
|
user_id=msg.get("user_id", 0),
|
||||||
user_nickname=nickname,
|
user_nickname=nickname,
|
||||||
)
|
)
|
||||||
group_info=GroupInfo(
|
group_info = GroupInfo(
|
||||||
platform='qq',
|
platform="qq",
|
||||||
group_id=msg.get("group_id", 0),
|
group_id=msg.get("group_id", 0),
|
||||||
group_name=get_groupname(msg.get("group_id", 0))
|
group_name=get_groupname(msg.get("group_id", 0)),
|
||||||
)
|
)
|
||||||
message_obj = MessageRecvCQ(
|
message_obj = MessageRecvCQ(
|
||||||
message_id=msg.get("message_id", 0),
|
message_id=msg.get("message_id", 0),
|
||||||
@@ -261,9 +227,8 @@ class CQCode:
|
|||||||
plain_text=raw_message,
|
plain_text=raw_message,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
content_seg = Seg(
|
await message_obj.initialize()
|
||||||
type="seglist", data=[message_obj.message_segment]
|
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
content_seg = Seg(type="text", data="[空消息]")
|
content_seg = Seg(type="text", data="[空消息]")
|
||||||
|
|
||||||
@@ -277,7 +242,7 @@ class CQCode:
|
|||||||
logger.error(f"处理转发消息失败: {str(e)}")
|
logger.error(f"处理转发消息失败: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def translate_reply(self) -> Optional[List[Seg]]:
|
async def translate_reply(self) -> Optional[List[Seg]]:
|
||||||
"""处理回复类型的CQ码,返回Seg列表"""
|
"""处理回复类型的CQ码,返回Seg列表"""
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
|
|
||||||
@@ -285,22 +250,19 @@ class CQCode:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if self.reply_message.sender.user_id:
|
if self.reply_message.sender.user_id:
|
||||||
|
|
||||||
message_obj = MessageRecvCQ(
|
message_obj = MessageRecvCQ(
|
||||||
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname),
|
user_info=UserInfo(
|
||||||
|
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
|
||||||
|
),
|
||||||
message_id=self.reply_message.message_id,
|
message_id=self.reply_message.message_id,
|
||||||
raw_message=str(self.reply_message.message),
|
raw_message=str(self.reply_message.message),
|
||||||
group_info=GroupInfo(group_id=self.reply_message.group_id),
|
group_info=GroupInfo(group_id=self.reply_message.group_id),
|
||||||
)
|
)
|
||||||
|
await message_obj.initialize()
|
||||||
|
|
||||||
segments = []
|
segments = []
|
||||||
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
|
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
|
||||||
segments.append(
|
segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
|
||||||
Seg(
|
|
||||||
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
segments.append(
|
segments.append(
|
||||||
Seg(
|
Seg(
|
||||||
@@ -318,16 +280,12 @@ class CQCode:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def unescape(text: str) -> str:
|
def unescape(text: str) -> str:
|
||||||
"""反转义CQ码中的特殊字符"""
|
"""反转义CQ码中的特殊字符"""
|
||||||
return (
|
return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&")
|
||||||
text.replace(",", ",")
|
|
||||||
.replace("[", "[")
|
|
||||||
.replace("]", "]")
|
|
||||||
.replace("&", "&")
|
|
||||||
)
|
|
||||||
|
|
||||||
class CQCode_tool:
|
class CQCode_tool:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
|
def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
|
||||||
"""
|
"""
|
||||||
将CQ码字典转换为CQCode对象
|
将CQ码字典转换为CQCode对象
|
||||||
|
|
||||||
@@ -353,11 +311,9 @@ class CQCode_tool:
|
|||||||
params=params,
|
params=params,
|
||||||
group_info=msg.message_info.group_info,
|
group_info=msg.message_info.group_info,
|
||||||
user_info=msg.message_info.user_info,
|
user_info=msg.message_info.user_info,
|
||||||
reply_message=reply
|
reply_message=reply,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 进行翻译处理
|
|
||||||
instance.translate()
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -383,12 +339,7 @@ class CQCode_tool:
|
|||||||
# 确保使用绝对路径
|
# 确保使用绝对路径
|
||||||
abs_path = os.path.abspath(file_path)
|
abs_path = os.path.abspath(file_path)
|
||||||
# 转义特殊字符
|
# 转义特殊字符
|
||||||
escaped_path = (
|
escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
||||||
abs_path.replace("&", "&")
|
|
||||||
.replace("[", "[")
|
|
||||||
.replace("]", "]")
|
|
||||||
.replace(",", ",")
|
|
||||||
)
|
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||||
|
|
||||||
@@ -403,10 +354,7 @@ class CQCode_tool:
|
|||||||
"""
|
"""
|
||||||
# 转义base64数据
|
# 转义base64数据
|
||||||
escaped_base64 = (
|
escaped_base64 = (
|
||||||
base64_data.replace("&", "&")
|
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
||||||
.replace("[", "[")
|
|
||||||
.replace("]", "]")
|
|
||||||
.replace(",", ",")
|
|
||||||
)
|
)
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||||
@@ -422,10 +370,7 @@ class CQCode_tool:
|
|||||||
"""
|
"""
|
||||||
# 转义base64数据
|
# 转义base64数据
|
||||||
escaped_base64 = (
|
escaped_base64 = (
|
||||||
base64_data.replace("&", "&")
|
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
||||||
.replace("[", "[")
|
|
||||||
.replace("]", "]")
|
|
||||||
.replace(",", ",")
|
|
||||||
)
|
)
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
|
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
|
||||||
|
|||||||
@@ -18,17 +18,17 @@ from ..chat.utils import get_embedding
|
|||||||
from ..chat.utils_image import ImageManager, image_path_to_base64
|
from ..chat.utils_image import ImageManager, image_path_to_base64
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
from ..utils.logger_config import setup_logger, LogModule
|
from ..utils.logger_config import LogClassification, LogModule
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = setup_logger(LogModule.EMOJI)
|
log_module = LogModule()
|
||||||
|
logger = log_module.setup_logger(LogClassification.EMOJI)
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
image_manager = ImageManager()
|
image_manager = ImageManager()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录
|
EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录
|
||||||
@@ -43,7 +43,7 @@ class EmojiManager:
|
|||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
||||||
self.llm_emotion_judge = LLM_request(
|
self.llm_emotion_judge = LLM_request(
|
||||||
model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8
|
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8
|
||||||
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
@@ -281,7 +281,6 @@ class EmojiManager:
|
|||||||
|
|
||||||
if description is not None:
|
if description is not None:
|
||||||
embedding = await get_embedding(description)
|
embedding = await get_embedding(description)
|
||||||
|
|
||||||
# 准备数据库记录
|
# 准备数据库记录
|
||||||
emoji_record = {
|
emoji_record = {
|
||||||
"filename": filename,
|
"filename": filename,
|
||||||
|
|||||||
@@ -25,30 +25,19 @@ class ResponseGenerator:
|
|||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
self.model_v3 = LLM_request(
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7, max_tokens=3000)
|
||||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
|
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=3000)
|
||||||
)
|
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7, max_tokens=3000)
|
||||||
self.model_r1_distill = LLM_request(
|
|
||||||
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
|
|
||||||
)
|
|
||||||
self.model_v25 = LLM_request(
|
|
||||||
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
|
|
||||||
)
|
|
||||||
self.current_model_type = "r1" # 默认使用 R1
|
self.current_model_type = "r1" # 默认使用 R1
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
|
||||||
self, message: MessageThinking
|
|
||||||
) -> Optional[Union[str, List[str]]]:
|
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
rand = random.random()
|
rand = random.random()
|
||||||
if rand < global_config.MODEL_R1_PROBABILITY:
|
if rand < global_config.MODEL_R1_PROBABILITY:
|
||||||
self.current_model_type = "r1"
|
self.current_model_type = "r1"
|
||||||
current_model = self.model_r1
|
current_model = self.model_r1
|
||||||
elif (
|
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
|
||||||
rand
|
|
||||||
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
|
|
||||||
):
|
|
||||||
self.current_model_type = "v3"
|
self.current_model_type = "v3"
|
||||||
current_model = self.model_v3
|
current_model = self.model_v3
|
||||||
else:
|
else:
|
||||||
@@ -57,24 +46,20 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
|
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
|
||||||
|
|
||||||
model_response = await self._generate_response_with_model(
|
model_response = await self._generate_response_with_model(message, current_model)
|
||||||
message, current_model
|
|
||||||
)
|
|
||||||
raw_content = model_response
|
raw_content = model_response
|
||||||
|
|
||||||
# print(f"raw_content: {raw_content}")
|
# print(f"raw_content: {raw_content}")
|
||||||
# print(f"model_response: {model_response}")
|
# print(f"model_response: {model_response}")
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
|
||||||
model_response = await self._process_response(model_response)
|
model_response = await self._process_response(model_response)
|
||||||
if model_response:
|
if model_response:
|
||||||
return model_response, raw_content
|
return model_response, raw_content
|
||||||
return None, raw_content
|
return None, raw_content
|
||||||
|
|
||||||
async def _generate_response_with_model(
|
async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request) -> Optional[str]:
|
||||||
self, message: MessageThinking, model: LLM_request
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""使用指定的模型生成回复"""
|
"""使用指定的模型生成回复"""
|
||||||
sender_name = ""
|
sender_name = ""
|
||||||
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
|
||||||
@@ -229,13 +214,11 @@ class InitiativeMessageGenerate:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
|
||||||
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
|
||||||
self.model_r1_distill = LLM_request(
|
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
|
||||||
model=global_config.llm_reasoning_minor, temperature=0.7
|
|
||||||
)
|
|
||||||
|
|
||||||
def gen_response(self, message: Message):
|
def gen_response(self, message: Message):
|
||||||
topic_select_prompt, dots_for_select, prompt_template = (
|
topic_select_prompt, dots_for_select, prompt_template = prompt_builder._build_initiative_prompt_select(
|
||||||
prompt_builder._build_initiative_prompt_select(message.group_id)
|
message.group_id
|
||||||
)
|
)
|
||||||
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
|
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
|
||||||
logger.debug(f"{content_select} {reasoning}")
|
logger.debug(f"{content_select} {reasoning}")
|
||||||
@@ -247,16 +230,12 @@ class InitiativeMessageGenerate:
|
|||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
prompt_check, memory = prompt_builder._build_initiative_prompt_check(
|
prompt_check, memory = prompt_builder._build_initiative_prompt_check(select_dot[1], prompt_template)
|
||||||
select_dot[1], prompt_template
|
|
||||||
)
|
|
||||||
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
|
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
|
||||||
logger.info(f"{content_check} {reasoning_check}")
|
logger.info(f"{content_check} {reasoning_check}")
|
||||||
if "yes" not in content_check.lower():
|
if "yes" not in content_check.lower():
|
||||||
return None
|
return None
|
||||||
prompt = prompt_builder._build_initiative_prompt(
|
prompt = prompt_builder._build_initiative_prompt(select_dot, prompt_template, memory)
|
||||||
select_dot, prompt_template, memory
|
|
||||||
)
|
|
||||||
content, reasoning = self.model_r1.generate_response_async(prompt)
|
content, reasoning = self.model_r1.generate_response_async(prompt)
|
||||||
logger.debug(f"[DEBUG] {content} {reasoning}")
|
logger.debug(f"[DEBUG] {content} {reasoning}")
|
||||||
return content
|
return content
|
||||||
|
|||||||
@@ -329,6 +329,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.message_segment,
|
self.message_segment,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
"""处理消息内容,生成纯文本和详细文本"""
|
"""处理消息内容,生成纯文本和详细文本"""
|
||||||
|
|||||||
@@ -57,16 +57,20 @@ class MessageRecvCQ(MessageCQ):
|
|||||||
# 私聊消息不携带group_info
|
# 私聊消息不携带group_info
|
||||||
if group_info is None:
|
if group_info is None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elif group_info.group_name is None:
|
elif group_info.group_name is None:
|
||||||
group_info.group_name = get_groupname(group_info.group_id)
|
group_info.group_name = get_groupname(group_info.group_id)
|
||||||
|
|
||||||
# 解析消息段
|
# 解析消息段
|
||||||
self.message_segment = self._parse_message(raw_message, reply_message)
|
self.message_segment = None # 初始化为None
|
||||||
self.raw_message = raw_message
|
self.raw_message = raw_message
|
||||||
|
# 异步初始化在外部完成
|
||||||
|
|
||||||
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
|
async def initialize(self):
|
||||||
"""解析消息内容为Seg对象"""
|
"""异步初始化方法"""
|
||||||
|
self.message_segment = await self._parse_message(self.raw_message)
|
||||||
|
|
||||||
|
async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
|
||||||
|
"""异步解析消息内容为Seg对象"""
|
||||||
cq_code_dict_list = []
|
cq_code_dict_list = []
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
@@ -98,9 +102,10 @@ class MessageRecvCQ(MessageCQ):
|
|||||||
|
|
||||||
# 转换CQ码为Seg对象
|
# 转换CQ码为Seg对象
|
||||||
for code_item in cq_code_dict_list:
|
for code_item in cq_code_dict_list:
|
||||||
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
|
cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
|
||||||
if message_obj.translated_segments:
|
await cq_code_obj.translate() # 异步调用translate
|
||||||
segments.append(message_obj.translated_segments)
|
if cq_code_obj.translated_segments:
|
||||||
|
segments.append(cq_code_obj.translated_segments)
|
||||||
|
|
||||||
# 如果只有一个segment,直接返回
|
# 如果只有一个segment,直接返回
|
||||||
if len(segments) == 1:
|
if len(segments) == 1:
|
||||||
@@ -133,9 +138,7 @@ class MessageSendCQ(MessageCQ):
|
|||||||
self.message_segment = message_segment
|
self.message_segment = message_segment
|
||||||
self.raw_message = self._generate_raw_message()
|
self.raw_message = self._generate_raw_message()
|
||||||
|
|
||||||
def _generate_raw_message(
|
def _generate_raw_message(self) -> str:
|
||||||
self,
|
|
||||||
) -> str:
|
|
||||||
"""将Seg对象转换为raw_message"""
|
"""将Seg对象转换为raw_message"""
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ from .relationship_manager import relationship_manager
|
|||||||
|
|
||||||
class PromptBuilder:
|
class PromptBuilder:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.prompt_built = ''
|
self.prompt_built = ""
|
||||||
self.activate_messages = ''
|
self.activate_messages = ""
|
||||||
|
|
||||||
async def _build_prompt(self,
|
async def _build_prompt(self,
|
||||||
chat_stream,
|
chat_stream,
|
||||||
@@ -88,46 +88,43 @@ class PromptBuilder:
|
|||||||
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||||
current_time = time.strftime("%H:%M:%S", time.localtime())
|
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||||
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
||||||
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
|
prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
|
||||||
|
|
||||||
# 知识构建
|
# 知识构建
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
prompt_info = ''
|
prompt_info = ""
|
||||||
promt_info_prompt = ''
|
promt_info_prompt = ""
|
||||||
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
|
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''你有以下这些[知识]:{prompt_info}请你记住上面的[
|
prompt_info = f"""你有以下这些[知识]:{prompt_info}请你记住上面的[
|
||||||
知识],之后可能会用到-'''
|
知识],之后可能会用到-"""
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒")
|
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
chat_in_group=True
|
chat_in_group = True
|
||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ""
|
||||||
if stream_id:
|
if stream_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
chat_talking_prompt = get_recent_group_detailed_plain_text(
|
||||||
chat_stream=chat_manager.get_stream(stream_id)
|
stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
|
||||||
|
)
|
||||||
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
if chat_stream.group_info:
|
if chat_stream.group_info:
|
||||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
else:
|
else:
|
||||||
chat_in_group=False
|
chat_in_group = False
|
||||||
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
|
||||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 使用新的记忆获取方法
|
# 使用新的记忆获取方法
|
||||||
memory_prompt = ''
|
memory_prompt = ""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 调用 hippocampus 的 get_relevant_memories 方法
|
# 调用 hippocampus 的 get_relevant_memories 方法
|
||||||
relevant_memories = await hippocampus.get_relevant_memories(
|
relevant_memories = await hippocampus.get_relevant_memories(
|
||||||
text=message_txt,
|
text=message_txt, max_topics=5, similarity_threshold=0.4, max_memory_num=5
|
||||||
max_topics=5,
|
|
||||||
similarity_threshold=0.4,
|
|
||||||
max_memory_num=5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if relevant_memories:
|
if relevant_memories:
|
||||||
@@ -147,7 +144,7 @@ class PromptBuilder:
|
|||||||
logger.info(f"回忆耗时: {(end_time - start_time):.3f}秒")
|
logger.info(f"回忆耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|
||||||
# 激活prompt构建
|
# 激活prompt构建
|
||||||
activate_prompt = ''
|
activate_prompt = ""
|
||||||
if chat_in_group:
|
if chat_in_group:
|
||||||
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt},\
|
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt},\
|
||||||
{relation_prompt}{relation_prompt2}现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意。请分析聊天记录,根据你和他的关系和态度进行回复,明确你的立场和情感。"
|
{relation_prompt}{relation_prompt2}现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意。请分析聊天记录,根据你和他的关系和态度进行回复,明确你的立场和情感。"
|
||||||
@@ -155,20 +152,22 @@ class PromptBuilder:
|
|||||||
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,{relation_prompt}{mood_prompt},你的回复态度是{relation_prompt2}"
|
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,{relation_prompt}{mood_prompt},你的回复态度是{relation_prompt2}"
|
||||||
|
|
||||||
# 关键词检测与反应
|
# 关键词检测与反应
|
||||||
keywords_reaction_prompt = ''
|
keywords_reaction_prompt = ""
|
||||||
for rule in global_config.keywords_reaction_rules:
|
for rule in global_config.keywords_reaction_rules:
|
||||||
if rule.get("enable", False):
|
if rule.get("enable", False):
|
||||||
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
|
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
|
||||||
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
|
logger.info(
|
||||||
keywords_reaction_prompt += rule.get("reaction", "") + ','
|
f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
|
||||||
|
)
|
||||||
|
keywords_reaction_prompt += rule.get("reaction", "") + ","
|
||||||
|
|
||||||
#人格选择
|
# 人格选择
|
||||||
personality=global_config.PROMPT_PERSONALITY
|
personality = global_config.PROMPT_PERSONALITY
|
||||||
probability_1 = global_config.PERSONALITY_1
|
probability_1 = global_config.PERSONALITY_1
|
||||||
probability_2 = global_config.PERSONALITY_2
|
probability_2 = global_config.PERSONALITY_2
|
||||||
probability_3 = global_config.PERSONALITY_3
|
probability_3 = global_config.PERSONALITY_3
|
||||||
|
|
||||||
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)},'
|
prompt_personality = f"{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{'/'.join(global_config.BOT_ALIAS_NAMES)},"
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
|
|
||||||
if personality_choice < probability_1: # 第一种人格
|
if personality_choice < probability_1: # 第一种人格
|
||||||
@@ -185,13 +184,13 @@ class PromptBuilder:
|
|||||||
请你表达自己的见解和观点。可以有个性。'''
|
请你表达自己的见解和观点。可以有个性。'''
|
||||||
|
|
||||||
# 中文高手(新加的好玩功能)
|
# 中文高手(新加的好玩功能)
|
||||||
prompt_ger = ''
|
prompt_ger = ""
|
||||||
if random.random() < 0.04:
|
if random.random() < 0.04:
|
||||||
prompt_ger += '你喜欢用倒装句'
|
prompt_ger += "你喜欢用倒装句"
|
||||||
if random.random() < 0.02:
|
if random.random() < 0.02:
|
||||||
prompt_ger += '你喜欢用反问句'
|
prompt_ger += "你喜欢用反问句"
|
||||||
if random.random() < 0.01:
|
if random.random() < 0.01:
|
||||||
prompt_ger += '你喜欢用文言文'
|
prompt_ger += "你喜欢用文言文"
|
||||||
|
|
||||||
# 额外信息要求
|
# 额外信息要求
|
||||||
extra_info = f'''但是记得你的回复态度和你的立场,切记你回复的人是{sender_name},不要输出你的思考过程,只需要输出最终的回复,务必简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
extra_info = f'''但是记得你的回复态度和你的立场,切记你回复的人是{sender_name},不要输出你的思考过程,只需要输出最终的回复,务必简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
||||||
@@ -225,38 +224,38 @@ class PromptBuilder:
|
|||||||
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||||
current_time = time.strftime("%H:%M:%S", time.localtime())
|
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||||
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
|
||||||
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
|
prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
|
||||||
|
|
||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ""
|
||||||
if group_id:
|
if group_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id,
|
chat_talking_prompt = get_recent_group_detailed_plain_text(
|
||||||
limit=global_config.MAX_CONTEXT_SIZE,
|
group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
|
||||||
combine=True)
|
)
|
||||||
|
|
||||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
|
|
||||||
# 获取主动发言的话题
|
# 获取主动发言的话题
|
||||||
all_nodes = memory_graph.dots
|
all_nodes = memory_graph.dots
|
||||||
all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes)
|
all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes)
|
||||||
nodes_for_select = random.sample(all_nodes, 5)
|
nodes_for_select = random.sample(all_nodes, 5)
|
||||||
topics = [info[0] for info in nodes_for_select]
|
topics = [info[0] for info in nodes_for_select]
|
||||||
infos = [info[1] for info in nodes_for_select]
|
infos = [info[1] for info in nodes_for_select]
|
||||||
|
|
||||||
# 激活prompt构建
|
# 激活prompt构建
|
||||||
activate_prompt = ''
|
activate_prompt = ""
|
||||||
activate_prompt = "以上是群里正在进行的聊天。"
|
activate_prompt = "以上是群里正在进行的聊天。"
|
||||||
personality = global_config.PROMPT_PERSONALITY
|
personality = global_config.PROMPT_PERSONALITY
|
||||||
prompt_personality = ''
|
prompt_personality = ""
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
if personality_choice < probability_1: # 第一种人格
|
if personality_choice < probability_1: # 第一种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}'''
|
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}"""
|
||||||
elif personality_choice < probability_1 + probability_2: # 第二种人格
|
elif personality_choice < probability_1 + probability_2: # 第二种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}'''
|
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}"""
|
||||||
else: # 第三种人格
|
else: # 第三种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}'''
|
prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}"""
|
||||||
|
|
||||||
topics_str = ','.join(f"\"{topics}\"")
|
topics_str = ",".join(f'"{topics}"')
|
||||||
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
|
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
|
||||||
|
|
||||||
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
|
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
|
||||||
@@ -265,8 +264,8 @@ class PromptBuilder:
|
|||||||
return prompt_initiative_select, nodes_for_select, prompt_regular
|
return prompt_initiative_select, nodes_for_select, prompt_regular
|
||||||
|
|
||||||
def _build_initiative_prompt_check(self, selected_node, prompt_regular):
|
def _build_initiative_prompt_check(self, selected_node, prompt_regular):
|
||||||
memory = random.sample(selected_node['memory_items'], 3)
|
memory = random.sample(selected_node["memory_items"], 3)
|
||||||
memory = '\n'.join(memory)
|
memory = "\n".join(memory)
|
||||||
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,综合群内的氛围,如果认为应该发言请输出yes,否则输出no,请注意是决定是否需要发言,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,以这个作为主题发言合适吗?请在把握群里的聊天内容的基础上,综合群内的氛围,如果认为应该发言请输出yes,否则输出no,请注意是决定是否需要发言,而不是编写回复内容,除了yes和no不要输出任何回复内容。"
|
||||||
return prompt_for_check, memory
|
return prompt_for_check, memory
|
||||||
|
|
||||||
@@ -275,7 +274,7 @@ class PromptBuilder:
|
|||||||
return prompt_for_initiative
|
return prompt_for_initiative
|
||||||
|
|
||||||
async def get_prompt_info(self, message: str, threshold: float):
|
async def get_prompt_info(self, message: str, threshold: float):
|
||||||
related_info = ''
|
related_info = ""
|
||||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
embedding = await get_embedding(message)
|
embedding = await get_embedding(message)
|
||||||
related_info += self.get_info_from_db(embedding, threshold=threshold)
|
related_info += self.get_info_from_db(embedding, threshold=threshold)
|
||||||
@@ -284,7 +283,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
return ''
|
return ""
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{
|
{
|
||||||
@@ -296,12 +295,14 @@ class PromptBuilder:
|
|||||||
"in": {
|
"in": {
|
||||||
"$add": [
|
"$add": [
|
||||||
"$$value",
|
"$$value",
|
||||||
{"$multiply": [
|
{
|
||||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
"$multiply": [
|
||||||
{"$arrayElemAt": [query_embedding, "$$this"]}
|
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||||
]}
|
{"$arrayElemAt": [query_embedding, "$$this"]},
|
||||||
|
]
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"magnitude1": {
|
"magnitude1": {
|
||||||
@@ -309,7 +310,7 @@ class PromptBuilder:
|
|||||||
"$reduce": {
|
"$reduce": {
|
||||||
"input": "$embedding",
|
"input": "$embedding",
|
||||||
"initialValue": 0,
|
"initialValue": 0,
|
||||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -318,19 +319,13 @@ class PromptBuilder:
|
|||||||
"$reduce": {
|
"$reduce": {
|
||||||
"input": query_embedding,
|
"input": query_embedding,
|
||||||
"initialValue": 0,
|
"initialValue": 0,
|
||||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$addFields": {
|
|
||||||
"similarity": {
|
|
||||||
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
||||||
{
|
{
|
||||||
"$match": {
|
"$match": {
|
||||||
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||||
@@ -338,17 +333,17 @@ class PromptBuilder:
|
|||||||
},
|
},
|
||||||
{"$sort": {"similarity": -1}},
|
{"$sort": {"similarity": -1}},
|
||||||
{"$limit": limit},
|
{"$limit": limit},
|
||||||
{"$project": {"content": 1, "similarity": 1}}
|
{"$project": {"content": 1, "similarity": 1}},
|
||||||
]
|
]
|
||||||
|
|
||||||
results = list(db.knowledges.aggregate(pipeline))
|
results = list(db.knowledges.aggregate(pipeline))
|
||||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return '\n'.join(str(result['content']) for result in results)
|
return "\n".join(str(result["content"]) for result in results)
|
||||||
|
|
||||||
|
|
||||||
prompt_builder = PromptBuilder()
|
prompt_builder = PromptBuilder()
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class ImageManager:
|
|||||||
self._ensure_description_collection()
|
self._ensure_description_collection()
|
||||||
self._ensure_image_dir()
|
self._ensure_image_dir()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000)
|
||||||
|
|
||||||
def _ensure_image_dir(self):
|
def _ensure_image_dir(self):
|
||||||
"""确保图像存储目录存在"""
|
"""确保图像存储目录存在"""
|
||||||
|
|||||||
@@ -19,10 +19,11 @@ from ..chat.utils import (
|
|||||||
)
|
)
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
from ..utils.logger_config import setup_logger, LogModule
|
from ..utils.logger_config import LogClassification, LogModule
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
logger = setup_logger(LogModule.MEMORY)
|
log_module = LogModule()
|
||||||
|
logger = log_module.setup_logger(LogClassification.MEMORY)
|
||||||
|
|
||||||
logger.info("初始化记忆系统")
|
logger.info("初始化记忆系统")
|
||||||
|
|
||||||
|
|||||||
@@ -1,71 +1,77 @@
|
|||||||
import sys
|
import sys
|
||||||
from loguru import logger
|
import loguru
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
class LogModule(Enum):
|
class LogClassification(Enum):
|
||||||
BASE = "base"
|
BASE = "base"
|
||||||
MEMORY = "memory"
|
MEMORY = "memory"
|
||||||
EMOJI = "emoji"
|
EMOJI = "emoji"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
|
|
||||||
def setup_logger(log_type: LogModule = LogModule.BASE):
|
class LogModule:
|
||||||
"""配置日志格式
|
logger = loguru.logger.opt()
|
||||||
|
|
||||||
Args:
|
def __init__(self):
|
||||||
log_type: 日志类型,可选值:BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)
|
pass
|
||||||
"""
|
def setup_logger(self, log_type: LogClassification):
|
||||||
# 移除默认的处理器
|
"""配置日志格式
|
||||||
logger.remove()
|
|
||||||
|
|
||||||
# 基础日志格式
|
Args:
|
||||||
base_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
log_type: 日志类型,可选值:BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)
|
||||||
|
"""
|
||||||
|
# 移除默认日志处理器
|
||||||
|
self.logger.remove()
|
||||||
|
|
||||||
chat_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
# 基础日志格式
|
||||||
|
base_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||||
|
|
||||||
# 记忆系统日志格式
|
chat_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||||
memory_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
|
||||||
|
|
||||||
# 表情包系统日志格式
|
# 记忆系统日志格式
|
||||||
emoji_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
memory_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <light-magenta>海马体</light-magenta> | <level>{message}</level>"
|
||||||
# 根据日志类型选择日志格式和输出
|
|
||||||
if log_type == LogModule.CHAT:
|
|
||||||
logger.add(
|
|
||||||
sys.stderr,
|
|
||||||
format=chat_format,
|
|
||||||
# level="INFO"
|
|
||||||
)
|
|
||||||
elif log_type == LogModule.MEMORY:
|
|
||||||
# 同时输出到控制台和文件
|
|
||||||
logger.add(
|
|
||||||
sys.stderr,
|
|
||||||
format=memory_format,
|
|
||||||
# level="INFO"
|
|
||||||
)
|
|
||||||
logger.add(
|
|
||||||
"logs/memory.log",
|
|
||||||
format=memory_format,
|
|
||||||
level="INFO",
|
|
||||||
rotation="1 day",
|
|
||||||
retention="7 days"
|
|
||||||
)
|
|
||||||
elif log_type == LogModule.EMOJI:
|
|
||||||
logger.add(
|
|
||||||
sys.stderr,
|
|
||||||
format=emoji_format,
|
|
||||||
# level="INFO"
|
|
||||||
)
|
|
||||||
logger.add(
|
|
||||||
"logs/emoji.log",
|
|
||||||
format=emoji_format,
|
|
||||||
level="INFO",
|
|
||||||
rotation="1 day",
|
|
||||||
retention="7 days"
|
|
||||||
)
|
|
||||||
else: # BASE
|
|
||||||
logger.add(
|
|
||||||
sys.stderr,
|
|
||||||
format=base_format,
|
|
||||||
level="INFO"
|
|
||||||
)
|
|
||||||
|
|
||||||
return logger
|
# 表情包系统日志格式
|
||||||
|
emoji_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
|
||||||
|
# 根据日志类型选择日志格式和输出
|
||||||
|
if log_type == LogClassification.CHAT:
|
||||||
|
self.logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format=chat_format,
|
||||||
|
# level="INFO"
|
||||||
|
)
|
||||||
|
elif log_type == LogClassification.MEMORY:
|
||||||
|
|
||||||
|
# 同时输出到控制台和文件
|
||||||
|
self.logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format=memory_format,
|
||||||
|
# level="INFO"
|
||||||
|
)
|
||||||
|
self.logger.add(
|
||||||
|
"logs/memory.log",
|
||||||
|
format=memory_format,
|
||||||
|
level="INFO",
|
||||||
|
rotation="1 day",
|
||||||
|
retention="7 days"
|
||||||
|
)
|
||||||
|
elif log_type == LogClassification.EMOJI:
|
||||||
|
self.logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format=emoji_format,
|
||||||
|
# level="INFO"
|
||||||
|
)
|
||||||
|
self.logger.add(
|
||||||
|
"logs/emoji.log",
|
||||||
|
format=emoji_format,
|
||||||
|
level="INFO",
|
||||||
|
rotation="1 day",
|
||||||
|
retention="7 days"
|
||||||
|
)
|
||||||
|
else: # BASE
|
||||||
|
self.logger.add(
|
||||||
|
sys.stderr,
|
||||||
|
format=base_format,
|
||||||
|
level="INFO"
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.logger
|
||||||
|
|||||||
Reference in New Issue
Block a user