Merge branch 'main-fix' into relationship

This commit is contained in:
meng_xi_pan
2025-03-14 18:52:13 +08:00
11 changed files with 262 additions and 332 deletions

11
bot.py
View File

@@ -10,9 +10,8 @@ import uvicorn
from dotenv import load_dotenv from dotenv import load_dotenv
from nonebot.adapters.onebot.v11 import Adapter from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from src.plugins.utils.logger_config import setup_logger from src.plugins.utils.logger_config import LogModule, LogClassification
from loguru import logger
# 配置日志格式 # 配置日志格式
@@ -102,7 +101,9 @@ def load_env():
def load_logger(): def load_logger():
setup_logger() global logger # 使得bot.py中其他函数也能调用
log_module = LogModule()
logger = log_module.setup_logger(LogClassification.BASE)
def scan_provider(env_config: dict): def scan_provider(env_config: dict):
@@ -174,8 +175,6 @@ def raw_main():
if platform.system().lower() != "windows": if platform.system().lower() != "windows":
time.tzset() time.tzset()
# 配置日志
load_logger()
easter_egg() easter_egg()
init_config() init_config()
init_env() init_env()
@@ -207,6 +206,8 @@ def raw_main():
if __name__ == "__main__": if __name__ == "__main__":
try: try:
# 配置日志使得主程序直接退出时候也能访问logger
load_logger()
raw_main() raw_main()
app = nonebot.get_asgi() app = nonebot.get_asgi()

View File

@@ -10,7 +10,6 @@ from nonebot.adapters.onebot.v11 import (
PokeNotifyEvent, PokeNotifyEvent,
GroupRecallNoticeEvent, GroupRecallNoticeEvent,
FriendRecallNoticeEvent, FriendRecallNoticeEvent,
) )
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
@@ -32,10 +31,11 @@ from .utils_image import image_path_to_base64
from .utils_user import get_user_nickname, get_user_cardname, get_groupname from .utils_user import get_user_nickname, get_user_cardname, get_groupname
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
from ..utils.logger_config import setup_logger, LogModule from ..utils.logger_config import LogClassification, LogModule
# 配置日志 # 配置日志
logger = setup_logger(LogModule.CHAT) log_module = LogModule()
logger = log_module.setup_logger(LogClassification.CHAT)
class ChatBot: class ChatBot:
@@ -95,8 +95,6 @@ class ChatBot:
await self.storage.store_recalled_message(event.message_id, time.time(), chat) await self.storage.store_recalled_message(event.message_id, time.time(), chat)
async def handle_message(self, event: MessageEvent, bot: Bot) -> None: async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
"""处理收到的消息""" """处理收到的消息"""
@@ -161,6 +159,7 @@ class ChatBot:
reply_message=event.reply, reply_message=event.reply,
platform="qq", platform="qq",
) )
await message_cq.initialize()
message_json = message_cq.to_dict() message_json = message_cq.to_dict()
# 进入maimbot # 进入maimbot
@@ -387,6 +386,7 @@ class ChatBot:
reply_message=None, reply_message=None,
platform="qq", platform="qq",
) )
await message_cq.initialize()
message_json = message_cq.to_dict() message_json = message_cq.to_dict()
message = MessageRecv(message_json) message = MessageRecv(message_json)

View File

@@ -1,48 +1,28 @@
import base64 import base64
import html import html
import time import time
import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import ssl
import os import os
import aiohttp
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .message_base import Seg from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname from .utils_user import get_user_nickname, get_groupname
from .message_base import GroupInfo, UserInfo from .message_base import GroupInfo, UserInfo
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
# TLS1.3特殊处理 https://github.com/psf/requests/issues/6616 # 创建SSL上下文
ctx = create_urllib3_context() ssl_context = ssl.create_default_context()
ctx.load_default_certs() ssl_context.set_ciphers("AES128-GCM-SHA256")
ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections,
maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass @dataclass
@@ -70,14 +50,12 @@ class CQCode:
"""初始化LLM实例""" """初始化LLM实例"""
pass pass
def translate(self): async def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text": if self.type == "text":
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
type="text", data=self.params.get("text", "")
)
elif self.type == "image": elif self.type == "image":
base64_data = self.translate_image() base64_data = await self.translate_image()
if base64_data: if base64_data:
if self.params.get("sub_type") == "0": if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data) self.translated_segments = Seg(type="image", data=base64_data)
@@ -90,22 +68,18 @@ class CQCode:
self.translated_segments = Seg(type="text", data="@[全体成员]") self.translated_segments = Seg(type="text", data="@[全体成员]")
else: else:
user_nickname = get_user_nickname(self.params.get("qq", "")) user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
type="text", data=f"[@{user_nickname or '某人'}]"
)
elif self.type == "reply": elif self.type == "reply":
reply_segments = self.translate_reply() reply_segments = await self.translate_reply()
if reply_segments: if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments) self.translated_segments = Seg(type="seglist", data=reply_segments)
else: else:
self.translated_segments = Seg(type="text", data="[回复某人消息]") self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face": elif self.type == "face":
face_id = self.params.get("id", "") face_id = self.params.get("id", "")
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == "forward": elif self.type == "forward":
forward_segments = self.translate_forward() forward_segments = await self.translate_forward()
if forward_segments: if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments) self.translated_segments = Seg(type="seglist", data=forward_segments)
else: else:
@@ -113,18 +87,8 @@ class CQCode:
else: else:
self.translated_segments = Seg(type="text", data=f"[{self.type}]") self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self): async def get_img(self) -> Optional[str]:
""" """异步获取图片并转换为base64"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
'Accept-Encoding': 'gzip, deflate, br',
'Connection': 'keep-alive',
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
"""
# 腾讯专用请求头配置
headers = { headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36", "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*", "Accept": "text/html, application/xhtml xml, */*",
@@ -133,61 +97,63 @@ class CQCode:
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache", "Cache-Control": "no-cache",
} }
url = html.unescape(self.params["url"]) url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")): if not url.startswith(("http://", "https://")):
return None return None
# 创建专用会话
session = requests.session()
session.adapters.pop("https://", None)
session.mount("https://", TencentSSLAdapter(ctx))
max_retries = 3 max_retries = 3
for retry in range(max_retries): for retry in range(max_retries):
try: try:
response = session.get( logger.debug(f"获取图片中: {url}")
url, # 设置SSL上下文和创建连接器
headers=headers, conn = aiohttp.TCPConnector(ssl=ssl_context)
timeout=15, async with aiohttp.ClientSession(connector=conn) as session:
allow_redirects=True, async with session.get(
stream=True, # 流式传输避免大内存问题 url,
) headers=headers,
timeout=aiohttp.ClientTimeout(total=15),
allow_redirects=True,
) as response:
# 腾讯服务器特殊状态码处理
if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
return None
# 腾讯服务器特殊状态码处理 if response.status != 200:
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url: raise aiohttp.ClientError(f"HTTP {response.status}")
return None
if response.status_code != 200: # 验证内容类型
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 验证内容类型 # 读取响应内容
content_type = response.headers.get("Content-Type", "") content = await response.read()
if not content_type.startswith("image/"): logger.debug(f"获取图片成功: {url}")
raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64 # 转换为Base64
image_base64 = base64.b64encode(response.content).decode("utf-8") image_base64 = base64.b64encode(content).decode("utf-8")
self.image_base64 = image_base64 self.image_base64 = image_base64
return image_base64 return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: except (aiohttp.ClientError, ValueError) as e:
if retry == max_retries - 1: if retry == max_retries - 1:
logger.error(f"最终请求失败: {str(e)}") logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5**retry) # 指数退避 await asyncio.sleep(1.5**retry) # 指数退避
except Exception: except Exception as e:
logger.exception("[未知错误]") logger.exception(f"获取图片时发生未知错误: {str(e)}")
return None return None
return None return None
def translate_image(self) -> Optional[str]: async def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串""" """处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params: if "url" not in self.params:
return None return None
return self.get_img() return await self.get_img()
def translate_forward(self) -> Optional[List[Seg]]: async def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表""" """处理转发消息返回Seg列表"""
try: try:
if "content" not in self.params: if "content" not in self.params:
@@ -217,15 +183,16 @@ class CQCode:
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
user_info=UserInfo(
platform='qq', user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0), user_id=msg.get("user_id", 0),
user_nickname=nickname, user_nickname=nickname,
) )
group_info=GroupInfo( group_info = GroupInfo(
platform='qq', platform="qq",
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)) group_name=get_groupname(msg.get("group_id", 0)),
) )
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
@@ -235,24 +202,23 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_info=group_info, group_info=group_info,
) )
content_seg = Seg( await message_obj.initialize()
type="seglist", data=[message_obj.message_segment] content_seg = Seg(type="seglist", data=[message_obj.message_segment])
)
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
user_info=UserInfo( user_info = UserInfo(
platform='qq', platform="qq",
user_id=msg.get("user_id", 0), user_id=msg.get("user_id", 0),
user_nickname=nickname, user_nickname=nickname,
) )
group_info=GroupInfo( group_info = GroupInfo(
platform='qq', platform="qq",
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)) group_name=get_groupname(msg.get("group_id", 0)),
) )
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0), message_id=msg.get("message_id", 0),
@@ -261,9 +227,8 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_info=group_info, group_info=group_info,
) )
content_seg = Seg( await message_obj.initialize()
type="seglist", data=[message_obj.message_segment] content_seg = Seg(type="seglist", data=[message_obj.message_segment])
)
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
@@ -277,7 +242,7 @@ class CQCode:
logger.error(f"处理转发消息失败: {str(e)}") logger.error(f"处理转发消息失败: {str(e)}")
return None return None
def translate_reply(self) -> Optional[List[Seg]]: async def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表""" """处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
@@ -285,22 +250,19 @@ class CQCode:
return None return None
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname), user_info=UserInfo(
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
),
message_id=self.reply_message.message_id, message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_info=GroupInfo(group_id=self.reply_message.group_id), group_info=GroupInfo(group_id=self.reply_message.group_id),
) )
await message_obj.initialize()
segments = [] segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append( segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else: else:
segments.append( segments.append(
Seg( Seg(
@@ -318,16 +280,12 @@ class CQCode:
@staticmethod @staticmethod
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return ( return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&")
text.replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace("&", "&")
)
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode: def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
@@ -353,11 +311,9 @@ class CQCode_tool:
params=params, params=params,
group_info=msg.message_info.group_info, group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info, user_info=msg.message_info.user_info,
reply_message=reply reply_message=reply,
) )
# 进行翻译处理
instance.translate()
return instance return instance
@staticmethod @staticmethod
@@ -383,12 +339,7 @@ class CQCode_tool:
# 确保使用绝对路径 # 确保使用绝对路径
abs_path = os.path.abspath(file_path) abs_path = os.path.abspath(file_path)
# 转义特殊字符 # 转义特殊字符
escaped_path = ( escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
abs_path.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@@ -403,10 +354,7 @@ class CQCode_tool:
""" """
# 转义base64数据 # 转义base64数据
escaped_base64 = ( escaped_base64 = (
base64_data.replace("&", "&") base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
) )
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@@ -422,10 +370,7 @@ class CQCode_tool:
""" """
# 转义base64数据 # 转义base64数据
escaped_base64 = ( escaped_base64 = (
base64_data.replace("&", "&") base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
) )
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"

View File

@@ -18,17 +18,17 @@ from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.logger_config import setup_logger, LogModule from ..utils.logger_config import LogClassification, LogModule
# 配置日志 # 配置日志
logger = setup_logger(LogModule.EMOJI) log_module = LogModule()
logger = log_module.setup_logger(LogClassification.EMOJI)
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
image_manager = ImageManager() image_manager = ImageManager()
class EmojiManager: class EmojiManager:
_instance = None _instance = None
EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录 EMOJI_DIR = os.path.join("data", "emoji") # 表情包存储目录
@@ -43,7 +43,7 @@ class EmojiManager:
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request( self.llm_emotion_judge = LLM_request(
model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8 model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8
) # 更高的温度更少的token后续可以根据情绪来调整温度 ) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
@@ -281,7 +281,6 @@ class EmojiManager:
if description is not None: if description is not None:
embedding = await get_embedding(description) embedding = await get_embedding(description)
# 准备数据库记录 # 准备数据库记录
emoji_record = { emoji_record = {
"filename": filename, "filename": filename,

View File

@@ -25,30 +25,19 @@ class ResponseGenerator:
max_tokens=1000, max_tokens=1000,
stream=True, stream=True,
) )
self.model_v3 = LLM_request( self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7, max_tokens=3000)
model=global_config.llm_normal, temperature=0.7, max_tokens=1000 self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=3000)
) self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7, max_tokens=3000)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
async def generate_response( async def generate_response(self, message: MessageThinking) -> Optional[Union[str, List[str]]]:
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
rand = random.random() rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY: if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = "r1" self.current_model_type = "r1"
current_model = self.model_r1 current_model = self.model_r1
elif ( elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3" self.current_model_type = "v3"
current_model = self.model_v3 current_model = self.model_v3
else: else:
@@ -57,24 +46,20 @@ class ResponseGenerator:
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model( model_response = await self._generate_response_with_model(message, current_model)
message, current_model
)
raw_content = model_response raw_content = model_response
# print(f"raw_content: {raw_content}") # print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}") # print(f"model_response: {model_response}")
if model_response: if model_response:
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') logger.info(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
return model_response, raw_content return model_response, raw_content
return None, raw_content return None, raw_content
async def _generate_response_with_model( async def _generate_response_with_model(self, message: MessageThinking, model: LLM_request) -> Optional[str]:
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复""" """使用指定的模型生成回复"""
sender_name = "" sender_name = ""
if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname: if message.chat_stream.user_info.user_cardname and message.chat_stream.user_info.user_nickname:
@@ -229,13 +214,11 @@ class InitiativeMessageGenerate:
def __init__(self): def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7) self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request( self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7)
model=global_config.llm_reasoning_minor, temperature=0.7
)
def gen_response(self, message: Message): def gen_response(self, message: Message):
topic_select_prompt, dots_for_select, prompt_template = ( topic_select_prompt, dots_for_select, prompt_template = prompt_builder._build_initiative_prompt_select(
prompt_builder._build_initiative_prompt_select(message.group_id) message.group_id
) )
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt) content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
logger.debug(f"{content_select} {reasoning}") logger.debug(f"{content_select} {reasoning}")
@@ -247,16 +230,12 @@ class InitiativeMessageGenerate:
return None return None
else: else:
return None return None
prompt_check, memory = prompt_builder._build_initiative_prompt_check( prompt_check, memory = prompt_builder._build_initiative_prompt_check(select_dot[1], prompt_template)
select_dot[1], prompt_template
)
content_check, reasoning_check = self.model_v3.generate_response(prompt_check) content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
logger.info(f"{content_check} {reasoning_check}") logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower(): if "yes" not in content_check.lower():
return None return None
prompt = prompt_builder._build_initiative_prompt( prompt = prompt_builder._build_initiative_prompt(select_dot, prompt_template, memory)
select_dot, prompt_template, memory
)
content, reasoning = self.model_r1.generate_response_async(prompt) content, reasoning = self.model_r1.generate_response_async(prompt)
logger.debug(f"[DEBUG] {content} {reasoning}") logger.debug(f"[DEBUG] {content} {reasoning}")
return content return content

View File

@@ -329,6 +329,7 @@ class MessageSending(MessageProcessBase):
self.message_segment, self.message_segment,
], ],
) )
return self
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本""" """处理消息内容,生成纯文本和详细文本"""

View File

@@ -57,16 +57,20 @@ class MessageRecvCQ(MessageCQ):
# 私聊消息不携带group_info # 私聊消息不携带group_info
if group_info is None: if group_info is None:
pass pass
elif group_info.group_name is None: elif group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id) group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段 # 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message) self.message_segment = None # 初始化为None
self.raw_message = raw_message self.raw_message = raw_message
# 异步初始化在外部完成
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg: async def initialize(self):
"""解析消息内容为Seg对象""" """异步初始化方法"""
self.message_segment = await self._parse_message(self.raw_message)
async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""异步解析消息内容为Seg对象"""
cq_code_dict_list = [] cq_code_dict_list = []
segments = [] segments = []
@@ -98,9 +102,10 @@ class MessageRecvCQ(MessageCQ):
# 转换CQ码为Seg对象 # 转换CQ码为Seg对象
for code_item in cq_code_dict_list: for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
if message_obj.translated_segments: await cq_code_obj.translate() # 异步调用translate
segments.append(message_obj.translated_segments) if cq_code_obj.translated_segments:
segments.append(cq_code_obj.translated_segments)
# 如果只有一个segment直接返回 # 如果只有一个segment直接返回
if len(segments) == 1: if len(segments) == 1:
@@ -133,9 +138,7 @@ class MessageSendCQ(MessageCQ):
self.message_segment = message_segment self.message_segment = message_segment
self.raw_message = self._generate_raw_message() self.raw_message = self._generate_raw_message()
def _generate_raw_message( def _generate_raw_message(self) -> str:
self,
) -> str:
"""将Seg对象转换为raw_message""" """将Seg对象转换为raw_message"""
segments = [] segments = []

View File

@@ -15,8 +15,8 @@ from .relationship_manager import relationship_manager
class PromptBuilder: class PromptBuilder:
def __init__(self): def __init__(self):
self.prompt_built = '' self.prompt_built = ""
self.activate_messages = '' self.activate_messages = ""
async def _build_prompt(self, async def _build_prompt(self,
chat_stream, chat_stream,
@@ -88,46 +88,43 @@ class PromptBuilder:
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
# 知识构建 # 知识构建
start_time = time.time() start_time = time.time()
prompt_info = '' prompt_info = ""
promt_info_prompt = '' promt_info_prompt = ""
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5) prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info: if prompt_info:
prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[ prompt_info = f"""你有以下这些[知识]{prompt_info}请你记住上面的[
知识],之后可能会用到-''' 知识],之后可能会用到-"""
end_time = time.time() end_time = time.time()
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_in_group=True chat_in_group = True
chat_talking_prompt = '' chat_talking_prompt = ""
if stream_id: if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(
chat_stream=chat_manager.get_stream(stream_id) stream_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
)
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream.group_info: if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else: else:
chat_in_group=False chat_in_group = False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 使用新的记忆获取方法 # 使用新的记忆获取方法
memory_prompt = '' memory_prompt = ""
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt, text=message_txt, max_topics=5, similarity_threshold=0.4, max_memory_num=5
max_topics=5,
similarity_threshold=0.4,
max_memory_num=5
) )
if relevant_memories: if relevant_memories:
@@ -147,7 +144,7 @@ class PromptBuilder:
logger.info(f"回忆耗时: {(end_time - start_time):.3f}") logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ""
if chat_in_group: if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt}\ activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt}\
{relation_prompt}{relation_prompt2}现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意。请分析聊天记录,根据你和他的关系和态度进行回复,明确你的立场和情感。" {relation_prompt}{relation_prompt2}现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意。请分析聊天记录,根据你和他的关系和态度进行回复,明确你的立场和情感。"
@@ -155,20 +152,22 @@ class PromptBuilder:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,{relation_prompt}{mood_prompt},你的回复态度是{relation_prompt2}" activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,{relation_prompt}{mood_prompt},你的回复态度是{relation_prompt2}"
# 关键词检测与反应 # 关键词检测与反应
keywords_reaction_prompt = '' keywords_reaction_prompt = ""
for rule in global_config.keywords_reaction_rules: for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False): if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}") logger.info(
keywords_reaction_prompt += rule.get("reaction", "") + '' f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}"
)
keywords_reaction_prompt += rule.get("reaction", "") + ""
#人格选择 # 人格选择
personality=global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1 probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2 probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3 probability_3 = global_config.PERSONALITY_3
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}' prompt_personality = f"{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{'/'.join(global_config.BOT_ALIAS_NAMES)}"
personality_choice = random.random() personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
@@ -185,13 +184,13 @@ class PromptBuilder:
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。'''
# 中文高手(新加的好玩功能) # 中文高手(新加的好玩功能)
prompt_ger = '' prompt_ger = ""
if random.random() < 0.04: if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句' prompt_ger += "你喜欢用倒装句"
if random.random() < 0.02: if random.random() < 0.02:
prompt_ger += '你喜欢用反问句' prompt_ger += "你喜欢用反问句"
if random.random() < 0.01: if random.random() < 0.01:
prompt_ger += '你喜欢用文言文' prompt_ger += "你喜欢用文言文"
# 额外信息要求 # 额外信息要求
extra_info = f'''但是记得你的回复态度和你的立场,切记你回复的人是{sender_name},不要输出你的思考过程,只需要输出最终的回复,务必简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' extra_info = f'''但是记得你的回复态度和你的立场,切记你回复的人是{sender_name},不要输出你的思考过程,只需要输出最终的回复,务必简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
@@ -225,38 +224,38 @@ class PromptBuilder:
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f"""今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n"""
chat_talking_prompt = '' chat_talking_prompt = ""
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id, chat_talking_prompt = get_recent_group_detailed_plain_text(
limit=global_config.MAX_CONTEXT_SIZE, group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True
combine=True) )
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题 # 获取主动发言的话题
all_nodes = memory_graph.dots all_nodes = memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes) all_nodes = filter(lambda dot: len(dot[1]["memory_items"]) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5) nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select] topics = [info[0] for info in nodes_for_select]
infos = [info[1] for info in nodes_for_select] infos = [info[1] for info in nodes_for_select]
# 激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ""
activate_prompt = "以上是群里正在进行的聊天。" activate_prompt = "以上是群里正在进行的聊天。"
personality = global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
prompt_personality = '' prompt_personality = ""
personality_choice = random.random() personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}"""
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}"""
else: # 第三种人格 else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}''' prompt_personality = f"""{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}"""
topics_str = ','.join(f"\"{topics}\"") topics_str = ",".join(f'"{topics}"')
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)" prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}" prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
@@ -265,8 +264,8 @@ class PromptBuilder:
return prompt_initiative_select, nodes_for_select, prompt_regular return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt_check(self, selected_node, prompt_regular): def _build_initiative_prompt_check(self, selected_node, prompt_regular):
memory = random.sample(selected_node['memory_items'], 3) memory = random.sample(selected_node["memory_items"], 3)
memory = '\n'.join(memory) memory = "\n".join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。" prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory return prompt_for_check, memory
@@ -275,7 +274,7 @@ class PromptBuilder:
return prompt_for_initiative return prompt_for_initiative
async def get_prompt_info(self, message: str, threshold: float): async def get_prompt_info(self, message: str, threshold: float):
related_info = '' related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message) embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding, threshold=threshold) related_info += self.get_info_from_db(embedding, threshold=threshold)
@@ -284,7 +283,7 @@ class PromptBuilder:
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
if not query_embedding: if not query_embedding:
return '' return ""
# 使用余弦相似度计算 # 使用余弦相似度计算
pipeline = [ pipeline = [
{ {
@@ -296,12 +295,14 @@ class PromptBuilder:
"in": { "in": {
"$add": [ "$add": [
"$$value", "$$value",
{"$multiply": [ {
{"$arrayElemAt": ["$embedding", "$$this"]}, "$multiply": [
{"$arrayElemAt": [query_embedding, "$$this"]} {"$arrayElemAt": ["$embedding", "$$this"]},
]} {"$arrayElemAt": [query_embedding, "$$this"]},
]
},
] ]
} },
} }
}, },
"magnitude1": { "magnitude1": {
@@ -309,7 +310,7 @@ class PromptBuilder:
"$reduce": { "$reduce": {
"input": "$embedding", "input": "$embedding",
"initialValue": 0, "initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
} }
} }
}, },
@@ -318,19 +319,13 @@ class PromptBuilder:
"$reduce": { "$reduce": {
"input": query_embedding, "input": query_embedding,
"initialValue": 0, "initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
} }
} }
} },
}
},
{
"$addFields": {
"similarity": {
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
}
} }
}, },
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{ {
"$match": { "$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果 "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
@@ -338,17 +333,17 @@ class PromptBuilder:
}, },
{"$sort": {"similarity": -1}}, {"$sort": {"similarity": -1}},
{"$limit": limit}, {"$limit": limit},
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}},
] ]
results = list(db.knowledges.aggregate(pipeline)) results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:
return '' return ""
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results) return "\n".join(str(result["content"]) for result in results)
prompt_builder = PromptBuilder() prompt_builder = PromptBuilder()

View File

@@ -34,7 +34,7 @@ class ImageManager:
self._ensure_description_collection() self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000)
def _ensure_image_dir(self): def _ensure_image_dir(self):
"""确保图像存储目录存在""" """确保图像存储目录存在"""

View File

@@ -19,10 +19,11 @@ from ..chat.utils import (
) )
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.logger_config import setup_logger, LogModule from ..utils.logger_config import LogClassification, LogModule
# 配置日志 # 配置日志
logger = setup_logger(LogModule.MEMORY) log_module = LogModule()
logger = log_module.setup_logger(LogClassification.MEMORY)
logger.info("初始化记忆系统") logger.info("初始化记忆系统")

View File

@@ -1,71 +1,77 @@
import sys import sys
from loguru import logger import loguru
from enum import Enum from enum import Enum
class LogModule(Enum): class LogClassification(Enum):
BASE = "base" BASE = "base"
MEMORY = "memory" MEMORY = "memory"
EMOJI = "emoji" EMOJI = "emoji"
CHAT = "chat" CHAT = "chat"
def setup_logger(log_type: LogModule = LogModule.BASE): class LogModule:
"""配置日志格式 logger = loguru.logger.opt()
Args: def __init__(self):
log_type: 日志类型可选值BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志) pass
""" def setup_logger(self, log_type: LogClassification):
# 移除默认的处理器 """配置日志格式
logger.remove()
# 基础日志格式 Args:
base_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>" log_type: 日志类型可选值BASE(基础日志)、MEMORY(记忆系统日志)、EMOJI(表情包系统日志)
"""
# 移除默认日志处理器
self.logger.remove()
chat_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>" # 基础日志格式
base_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
# 记忆系统日志格式 chat_format = "<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
memory_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <light-magenta>海马体</light-magenta> | <level>{message}</level>"
# 表情包系统日志格式 # 记忆系统日志格式
emoji_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>" memory_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <light-magenta>海马体</light-magenta> | <level>{message}</level>"
# 根据日志类型选择日志格式和输出
if log_type == LogModule.CHAT:
logger.add(
sys.stderr,
format=chat_format,
# level="INFO"
)
elif log_type == LogModule.MEMORY:
# 同时输出到控制台和文件
logger.add(
sys.stderr,
format=memory_format,
# level="INFO"
)
logger.add(
"logs/memory.log",
format=memory_format,
level="INFO",
rotation="1 day",
retention="7 days"
)
elif log_type == LogModule.EMOJI:
logger.add(
sys.stderr,
format=emoji_format,
# level="INFO"
)
logger.add(
"logs/emoji.log",
format=emoji_format,
level="INFO",
rotation="1 day",
retention="7 days"
)
else: # BASE
logger.add(
sys.stderr,
format=base_format,
level="INFO"
)
return logger # 表情包系统日志格式
emoji_format = "<green>{time:HH:mm}</green> | <level>{level: <8}</level> | <yellow>表情包</yellow> | <cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
# 根据日志类型选择日志格式和输出
if log_type == LogClassification.CHAT:
self.logger.add(
sys.stderr,
format=chat_format,
# level="INFO"
)
elif log_type == LogClassification.MEMORY:
# 同时输出到控制台和文件
self.logger.add(
sys.stderr,
format=memory_format,
# level="INFO"
)
self.logger.add(
"logs/memory.log",
format=memory_format,
level="INFO",
rotation="1 day",
retention="7 days"
)
elif log_type == LogClassification.EMOJI:
self.logger.add(
sys.stderr,
format=emoji_format,
# level="INFO"
)
self.logger.add(
"logs/emoji.log",
format=emoji_format,
level="INFO",
rotation="1 day",
retention="7 days"
)
else: # BASE
self.logger.add(
sys.stderr,
format=base_format,
level="INFO"
)
return self.logger