feat: 重构完成开始测试debug

This commit is contained in:
tcmofashi
2025-03-11 01:15:32 +08:00
parent 20b8778e2b
commit 7899e67cb2
13 changed files with 486 additions and 572 deletions

View File

@@ -1,6 +1,5 @@
import time import time
from random import random from random import random
from loguru import logger from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
@@ -11,25 +10,18 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import ( from .message_cq import (
MessageRecvCQ, MessageRecvCQ,
MessageSendCQ,
)
from .chat_stream import chat_manager
MessageRecvCQ,
MessageSendCQ,
) )
from .chat_stream import chat_manager from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils_image import image_path_to_base64 from .utils_image import image_path_to_base64
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot: class ChatBot:
def __init__(self): def __init__(self):
@@ -53,24 +45,21 @@ class ChatBot:
self.bot = bot # 更新 bot 实例 self.bot = bot # 更新 bot 实例
group_info = await bot.get_group_info(group_id=event.group_id) group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) # 白名单设定由nontbot侧完成
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
if event.user_id in global_config.ban_user_id:
return
message_cq=MessageRecvCQ(
message_cq=MessageRecvCQ( message_cq=MessageRecvCQ(
message_id=event.message_id, message_id=event.message_id,
user_id=event.user_id, user_id=event.user_id,
raw_message=str(event.original_message), raw_message=str(event.original_message),
group_id=event.group_id, group_id=event.group_id,
user_id=event.user_id,
raw_message=str(event.original_message),
group_id=event.group_id,
reply_message=event.reply, reply_message=event.reply,
platform='qq' platform='qq'
) )
@@ -78,37 +67,26 @@ class ChatBot:
# 进入maimbot # 进入maimbot
message=MessageRecv(**message_json) message=MessageRecv(**message_json)
await message.process()
groupinfo=message.message_info.group_info groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info userinfo=message.message_info.user_info
messageinfo=message.message_info messageinfo=message.message_info
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
# 消息过滤涉及到config有待更新 # 消息过滤涉及到config有待更新
if groupinfo:
if groupinfo.group_id not in global_config.talk_allowed_groups: chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
return await relationship_manager.update_relationship(chat_stream=chat,)
else: await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
if userinfo:
if userinfo.user_id in []: await message.process()
pass
else:
return
else:
return
if userinfo.user_id in global_config.ban_user_id:
return
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
if word in message.processed_plain_text: if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}") logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered") logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
@@ -130,20 +108,13 @@ class ChatBot:
is_emoji=message.is_emoji, is_emoji=message.is_emoji,
interested_rate=interested_rate interested_rate=interested_rate
) )
current_willing = willing_manager.get_willing( current_willing = willing_manager.get_willing(chat_stream=chat)
chat_stream=chat
)
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m") print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
response = None response = None
if random() < reply_probability: if random() < reply_probability:
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform
)
bot_user_info=UserInfo( bot_user_info=UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.BOT_NICKNAME,
@@ -151,22 +122,16 @@ class ChatBot:
) )
tinking_time_point = round(time.time(), 2) tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point) think_id = 'mt' + str(tinking_time_point)
thinking_message = MessageThinking.from_chat_stream( thinking_message = MessageThinking(
chat_stream=chat,
message_id=think_id, message_id=think_id,
reply=message
)
thinking_message = MessageThinking.from_chat_stream(
chat_stream=chat, chat_stream=chat,
message_id=think_id, bot_user_info=bot_user_info,
reply=message reply=message
) )
message_manager.add_message(thinking_message) message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent( willing_manager.change_reply_willing_sent(chat)
chat_stream=chat
)
response,raw_content = await self.gpt.generate_response(message) response,raw_content = await self.gpt.generate_response(message)
@@ -201,18 +166,11 @@ class ChatBot:
accu_typing_time += typing_time accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time timepoint = tinking_time_point + accu_typing_time
message_segment = Seg(type='text', data=msg)
bot_message = MessageSending(
message_segment = Seg(type='text', data=msg) message_segment = Seg(type='text', data=msg)
bot_message = MessageSending( bot_message = MessageSending(
message_id=think_id, message_id=think_id,
chat_stream=chat, chat_stream=chat,
message_segment=message_segment, bot_user_info=bot_user_info,
reply=message,
is_head=not mark_head,
is_emoji=False
)
chat_stream=chat,
message_segment=message_segment, message_segment=message_segment,
reply=message, reply=message,
is_head=not mark_head, is_head=not mark_head,
@@ -235,7 +193,6 @@ class ChatBot:
if emoji_raw != None: if emoji_raw != None:
emoji_path,discription = emoji_raw emoji_path,discription = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
emoji_cq = image_path_to_base64(emoji_path) emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5: if random() < 0.5:
@@ -247,15 +204,7 @@ class ChatBot:
bot_message = MessageSending( bot_message = MessageSending(
message_id=think_id, message_id=think_id,
chat_stream=chat, chat_stream=chat,
message_segment=message_segment, bot_user_info=bot_user_info,
reply=message,
is_head=False,
is_emoji=True
)
message_segment = Seg(type='emoji', data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
message_segment=message_segment, message_segment=message_segment,
reply=message, reply=message,
is_head=False, is_head=False,
@@ -273,20 +222,12 @@ class ChatBot:
'fearful': -0.7, 'fearful': -0.7,
'neutral': 0.1 'neutral': 0.1
} }
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=valuedict[emotion[0]])
# 使用情绪管理器更新情绪 # 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
willing_manager.change_reply_willing_after_sent( willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform, chat_stream=chat
user_info=userinfo,
group_info=groupinfo
)
willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo
) )
# 创建全局ChatBot实例 # 创建全局ChatBot实例

View File

@@ -1,53 +1,65 @@
import time
import asyncio import asyncio
from typing import Optional, Dict, Tuple
import hashlib import hashlib
import time
from typing import Dict, Optional
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import Database
from .message_base import UserInfo, GroupInfo from .message_base import GroupInfo, UserInfo
class ChatStream: class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文""" """聊天流对象,存储一个完整的聊天上下文"""
def __init__(self,
stream_id: str, def __init__(
platform: str, self,
user_info: UserInfo, stream_id: str,
group_info: Optional[GroupInfo] = None, platform: str,
data: dict = None): user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id self.stream_id = stream_id
self.platform = platform self.platform = platform
self.user_info = user_info self.user_info = user_info
self.group_info = group_info self.group_info = group_info
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time()) self.create_time = (
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False self.saved = False
def to_dict(self) -> dict: def to_dict(self) -> dict:
"""转换为字典格式""" """转换为字典格式"""
result = { result = {
'stream_id': self.stream_id, "stream_id": self.stream_id,
'platform': self.platform, "platform": self.platform,
'user_info': self.user_info.to_dict() if self.user_info else None, "user_info": self.user_info.to_dict() if self.user_info else None,
'group_info': self.group_info.to_dict() if self.group_info else None, "group_info": self.group_info.to_dict() if self.group_info else None,
'create_time': self.create_time, "create_time": self.create_time,
'last_active_time': self.last_active_time "last_active_time": self.last_active_time,
} }
return result return result
@classmethod @classmethod
def from_dict(cls, data: dict) -> 'ChatStream': def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例""" """从字典创建实例"""
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None user_info = (
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls( return cls(
stream_id=data['stream_id'], stream_id=data["stream_id"],
platform=data['platform'], platform=data["platform"],
user_info=user_info, user_info=user_info,
group_info=group_info, group_info=group_info,
data=data data=data,
) )
def update_active_time(self): def update_active_time(self):
@@ -58,6 +70,7 @@ class ChatStream:
class ChatManager: class ChatManager:
"""聊天管理器,管理所有聊天流""" """聊天管理器,管理所有聊天流"""
_instance = None _instance = None
_initialized = False _initialized = False
@@ -97,33 +110,32 @@ class ChatManager:
def _ensure_collection(self): def _ensure_collection(self):
"""确保数据库集合存在并创建索引""" """确保数据库集合存在并创建索引"""
if 'chat_streams' not in self.db.db.list_collection_names(): if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection('chat_streams') self.db.db.create_collection("chat_streams")
# 创建索引 # 创建索引
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True) self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index([ self.db.db.chat_streams.create_index(
('platform', 1), [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
('user_info.user_id', 1), )
('group_info.group_id', 1)
])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str: def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
# 组合关键信息 # 组合关键信息
components = [ components = [
platform, platform,
str(user_info.user_id), str(user_info.user_id),
str(group_info.group_id) if group_info else 'private' str(group_info.group_id) if group_info else "private",
] ]
# 使用MD5生成唯一ID # 使用MD5生成唯一ID
key = '_'.join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(self, async def get_or_create_stream(
platform: str, self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
user_info: UserInfo, ) -> ChatStream:
group_info: Optional[GroupInfo] = None) -> ChatStream:
"""获取或创建聊天流 """获取或创建聊天流
Args: Args:
@@ -148,7 +160,7 @@ class ChatManager:
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({'stream_id': stream_id}) data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data: if data:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
@@ -162,7 +174,7 @@ class ChatManager:
stream_id=stream_id, stream_id=stream_id,
platform=platform, platform=platform,
user_info=user_info, user_info=user_info,
group_info=group_info group_info=group_info,
) )
# 保存到内存和数据库 # 保存到内存和数据库
@@ -174,10 +186,9 @@ class ChatManager:
"""通过stream_id获取聊天流""" """通过stream_id获取聊天流"""
return self.streams.get(stream_id) return self.streams.get(stream_id)
def get_stream_by_info(self, def get_stream_by_info(
platform: str, self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
user_info: UserInfo, ) -> Optional[ChatStream]:
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
"""通过信息获取聊天流""" """通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info) stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id) return self.streams.get(stream_id)
@@ -186,9 +197,7 @@ class ChatManager:
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
self.db.db.chat_streams.update_one( self.db.db.chat_streams.update_one(
{'stream_id': stream.stream_id}, {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
{'$set': stream.to_dict()},
upsert=True
) )
stream.saved = True stream.saved = True

View File

@@ -3,23 +3,22 @@ import html
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, List, Union from typing import Dict, List, Optional, Union
import requests import requests
# 解析各种CQ码 # 解析各种CQ码
# 包含CQ码类 # 包含CQ码类
import urllib3 import urllib3
from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from urllib3.util import create_urllib3_context from urllib3.util import create_urllib3_context
from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .utils_image import image_manager
from .utils_user import get_user_nickname
from .message_base import Seg from .message_base import Seg
from .utils_user import get_user_nickname
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -37,8 +36,11 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False): def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager( self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections, maxsize=maxsize, num_pools=connections,
block=block, ssl_context=self.ssl_context) maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass @dataclass
@@ -52,6 +54,7 @@ class CQCode:
raw_code: 原始CQ码字符串 raw_code: 原始CQ码字符串
translated_segments: 经过处理后的Seg对象列表 translated_segments: 经过处理后的Seg对象列表
""" """
type: str type: str
params: Dict[str, str] params: Dict[str, str]
group_id: int group_id: int
@@ -65,77 +68,52 @@ class CQCode:
def __post_init__(self): def __post_init__(self):
"""初始化LLM实例""" """初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) self._llm = LLM_request(
model=global_config.vlm, temperature=0.4, max_tokens=300
)
def translate(self): def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == 'text': if self.type == "text":
self.translated_segments = Seg( self.translated_segments = Seg(
type='text', type="text", data=self.params.get("text", "")
data=self.params.get('text', '')
) )
elif self.type == 'image': elif self.type == "image":
base64_data = self.translate_image() base64_data = self.translate_image()
if base64_data: if base64_data:
if self.params.get('sub_type') == '0': if self.params.get("sub_type") == "0":
self.translated_segments = Seg( self.translated_segments = Seg(type="image", data=base64_data)
type='image',
data=base64_data
)
else: else:
self.translated_segments = Seg( self.translated_segments = Seg(type="emoji", data=base64_data)
type='emoji',
data=base64_data
)
else: else:
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data="[图片]")
type='text', elif self.type == "at":
data='[图片]' user_nickname = get_user_nickname(self.params.get("qq", ""))
)
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
self.translated_segments = Seg( self.translated_segments = Seg(
type='text', type="text", data=f"[@{user_nickname or '某人'}]"
data=f"[@{user_nickname or '某人'}]"
) )
elif self.type == 'reply': elif self.type == "reply":
reply_segments = self.translate_reply() reply_segments = self.translate_reply()
if reply_segments: if reply_segments:
self.translated_segments = Seg( self.translated_segments = Seg(type="seglist", data=reply_segments)
type='seglist',
data=reply_segments
)
else: else:
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data="[回复某人消息]")
type='text', elif self.type == "face":
data='[回复某人消息]' face_id = self.params.get("id", "")
)
elif self.type == 'face':
face_id = self.params.get('id', '')
self.translated_segments = Seg( self.translated_segments = Seg(
type='text', type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
data=f"[{emojimapper.get(int(face_id), '表情')}]"
) )
elif self.type == 'forward': elif self.type == "forward":
forward_segments = self.translate_forward() forward_segments = self.translate_forward()
if forward_segments: if forward_segments:
self.translated_segments = Seg( self.translated_segments = Seg(type="seglist", data=forward_segments)
type='seglist',
data=forward_segments
)
else: else:
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data="[转发消息]")
type='text',
data='[转发消息]'
)
else: else:
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data=f"[{self.type}]")
type='text',
data=f"[{self.type}]"
)
def get_img(self): def get_img(self):
''' """
headers = { headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0', 'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8', 'Accept': 'image/*;q=0.8',
@@ -144,18 +122,18 @@ class CQCode:
'Cache-Control': 'no-cache', 'Cache-Control': 'no-cache',
'Pragma': 'no-cache' 'Pragma': 'no-cache'
} }
''' """
# 腾讯专用请求头配置 # 腾讯专用请求头配置
headers = { headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36', "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
'Accept': 'text/html, application/xhtml xml, */*', "Accept": "text/html, application/xhtml xml, */*",
'Accept-Encoding': 'gbk, GB2312', "Accept-Encoding": "gbk, GB2312",
'Accept-Language': 'zh-cn', "Accept-Language": "zh-cn",
'Content-Type': 'application/x-www-form-urlencoded', "Content-Type": "application/x-www-form-urlencoded",
'Cache-Control': 'no-cache' "Cache-Control": "no-cache",
} }
url = html.unescape(self.params['url']) url = html.unescape(self.params["url"])
if not url.startswith(('http://', 'https://')): if not url.startswith(("http://", "https://")):
return None return None
# 创建专用会话 # 创建专用会话
@@ -171,30 +149,30 @@ class CQCode:
headers=headers, headers=headers,
timeout=15, timeout=15,
allow_redirects=True, allow_redirects=True,
stream=True # 流式传输避免大内存问题 stream=True, # 流式传输避免大内存问题
) )
# 腾讯服务器特殊状态码处理 # 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None return None
if response.status_code != 200: if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型 # 验证内容类型
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not content_type.startswith('image/'): if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}") raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64 # 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8') image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64 self.image_base64 = image_base64
return image_base64 return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1: if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}") print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避 time.sleep(1.5**retry) # 指数退避
except Exception as e: except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}") print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
@@ -202,21 +180,21 @@ class CQCode:
return None return None
def translate_image(self) -> Optional[str]: def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串""" """处理图片类型的CQ码返回base64字符串"""
if 'url' not in self.params: if "url" not in self.params:
return None return None
return self.get_img() return self.get_img()
def translate_forward(self) -> Optional[List[Seg]]: def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表""" """处理转发消息返回Seg列表"""
try: try:
if 'content' not in self.params: if "content" not in self.params:
return None return None
content = self.unescape(self.params['content']) content = self.unescape(self.params["content"])
import ast import ast
try: try:
messages = ast.literal_eval(content) messages = ast.literal_eval(content)
except ValueError as e: except ValueError as e:
@@ -225,46 +203,52 @@ class CQCode:
formatted_segments = [] formatted_segments = []
for msg in messages: for msg in messages:
sender = msg.get('sender', {}) sender = msg.get("sender", {})
nickname = sender.get('card') or sender.get('nickname', '未知用户') nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get('raw_message', '') raw_message = msg.get("raw_message", "")
message_array = msg.get('message', []) message_array = msg.get("message", [])
if message_array and isinstance(message_array, list): if message_array and isinstance(message_array, list):
for message_part in message_array: for message_part in message_array:
if message_part.get('type') == 'forward': if message_part.get("type") == "forward":
content_seg = Seg(type='text', data='[转发消息]') content_seg = Seg(type="text", data="[转发消息]")
break break
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0), user_id=msg.get("user_id", 0),
message_id=msg.get('message_id', 0), message_id=msg.get("message_id", 0),
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
) )
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else: else:
content_seg = Seg(type='text', data='[空消息]') content_seg = Seg(type="text", data="[空消息]")
else: else:
if raw_message: if raw_message:
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ( message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0), user_id=msg.get("user_id", 0),
message_id=msg.get('message_id', 0), message_id=msg.get("message_id", 0),
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
) )
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else: else:
content_seg = Seg(type='text', data='[空消息]') content_seg = Seg(type="text", data="[空消息]")
formatted_segments.append(Seg(type='text', data=f"{nickname}: ")) formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_segments.append(content_seg) formatted_segments.append(content_seg)
formatted_segments.append(Seg(type='text', data='\n')) formatted_segments.append(Seg(type="text", data="\n"))
return formatted_segments return formatted_segments
@@ -275,6 +259,7 @@ class CQCode:
def translate_reply(self) -> Optional[List[Seg]]: def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表""" """处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ from .message_cq import MessageRecvCQ
if self.reply_message is None: if self.reply_message is None:
return None return None
@@ -283,17 +268,26 @@ class CQCode:
user_id=self.reply_message.sender.user_id, user_id=self.reply_message.sender.user_id,
message_id=self.reply_message.message_id, message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_id=self.group_id group_id=self.group_id,
) )
segments = [] segments = []
if message_obj.user_id == global_config.BOT_QQ: if message_obj.user_id == global_config.BOT_QQ:
segments.append(Seg(type='text', data=f"[回复 {global_config.BOT_NICKNAME} 的消息: ")) segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else: else:
segments.append(Seg(type='text', data=f"[回复 {self.reply_message.sender.nickname} 的消息: ")) segments.append(
Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type='seglist', data=message_obj.message_segments)) segments.append(Seg(type="seglist", data=message_obj.message_segments))
segments.append(Seg(type='text', data="]")) segments.append(Seg(type="text", data="]"))
return segments return segments
else: else:
return None return None
@@ -301,12 +295,12 @@ class CQCode:
@staticmethod @staticmethod
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \ return (
.replace('&#91;', '[') \ text.replace("&#44;", ",")
.replace('&#93;', ']') \ .replace("&#91;", "[")
.replace('&amp;', '&') .replace("&#93;", "]")
.replace("&amp;", "&")
)
class CQCode_tool: class CQCode_tool:
@@ -324,19 +318,15 @@ class CQCode_tool:
""" """
# 处理字典形式的CQ码 # 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text' # 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get('type', 'text') cq_type = cq_code.get("type", "text")
params = {} params = {}
if cq_type == 'text': if cq_type == "text":
params['text'] = cq_code.get('data', {}).get('text', '') params["text"] = cq_code.get("data", {}).get("text", "")
else: else:
params = cq_code.get('data', {}) params = cq_code.get("data", {})
instance = CQCode( instance = CQCode(
type=cq_type, type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply
params=params,
group_id=0,
user_id=0,
reply_message=reply
) )
# 进行翻译处理 # 进行翻译处理
@@ -366,10 +356,12 @@ class CQCode_tool:
# 确保使用绝对路径 # 确保使用绝对路径
abs_path = os.path.abspath(file_path) abs_path = os.path.abspath(file_path)
# 转义特殊字符 # 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \ escaped_path = (
.replace('[', '&#91;') \ abs_path.replace("&", "&amp;")
.replace(']', '&#93;') \ .replace("[", "&#91;")
.replace(',', '&#44;') .replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@@ -383,15 +375,14 @@ class CQCode_tool:
表情包CQ码字符串 表情包CQ码字符串
""" """
# 转义base64数据 # 转义base64数据
escaped_base64 = base64_data.replace('&', '&amp;') \ escaped_base64 = (
.replace('[', '&#91;') \ base64_data.replace("&", "&amp;")
.replace(']', '&#93;') \ .replace("[", "&#91;")
.replace(',', '&#44;') .replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()

View File

@@ -1,11 +1,11 @@
import asyncio import asyncio
import base64
import hashlib
import os import os
import random import random
import time import time
import traceback import traceback
from typing import Optional, Tuple from typing import Optional, Tuple
import base64
import hashlib
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -13,9 +13,8 @@ from nonebot import get_driver
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..chat.utils_image import ImageManager
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -78,7 +77,6 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names(): if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji') self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True) self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .message_cq import Message from .message import MessageRecv, MessageThinking, MessageSending
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .utils import process_llm_response from .utils import process_llm_response
@@ -18,48 +18,78 @@ config = driver.config
class ResponseGenerator: class ResponseGenerator:
def __init__(self): def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True) self.model_r1 = LLM_request(
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000) model=global_config.llm_reasoning,
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000) temperature=0.7,
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000) max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance() self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]: async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
rand = random.random() rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY: if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1' self.current_model_type = "r1"
current_model = self.model_r1 current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY: elif (
self.current_model_type = 'v3' rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3 current_model = self.model_v3
else: else:
self.current_model_type = 'r1_distill' self.current_model_type = "r1_distill"
current_model = self.model_r1_distill current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") print(
f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++"
)
model_response = await self._generate_response_with_model(message, current_model) model_response = await self._generate_response_with_model(
raw_content=model_response message, current_model
)
raw_content = model_response
if model_response: if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') print(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
return model_response, raw_content
return None, raw_content
return model_response ,raw_content async def _generate_response_with_model(
return None,raw_content self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
"""使用指定的模型生成回复""" """使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}" sender_name = (
if message.user_cardname: message.chat_stream.user_info.user_nickname
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}" or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值 # 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0 relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0: if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") # print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass pass
@@ -69,7 +99,7 @@ class ResponseGenerator:
message_txt=message.processed_plain_text, message_txt=message.processed_plain_text,
sender_name=sender_name, sender_name=sender_name,
relationship_value=relationship_value, relationship_value=relationship_value,
group_id=message.group_id stream_id=message.chat_stream.stream_id,
) )
# 读空气模块 简化逻辑,先停用 # 读空气模块 简化逻辑,先停用
@@ -112,34 +142,51 @@ class ResponseGenerator:
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, # def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, def _save_to_db(
content: str, reasoning_content: str,): self,
message: Message,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库""" """保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({ self.db.db.reasoning_logs.insert_one(
'time': time.time(), {
'group_id': message.group_id, "time": time.time(),
'user': sender_name, "group_id": message.group_id,
'message': message.processed_plain_text, "user": sender_name,
'model': self.current_model_type, "message": message.processed_plain_text,
# 'reasoning_check': reasoning_content_check, "model": self.current_model_type,
# 'response_check': content_check, # 'reasoning_check': reasoning_content_check,
'reasoning': reasoning_content, # 'response_check': content_check,
'response': content, "reasoning": reasoning_content,
'prompt': prompt, "response": content,
'prompt_check': prompt_check "prompt": prompt,
}) "prompt_check": prompt_check,
}
)
async def _get_emotion_tags(self, content: str) -> List[str]: async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签""" """提取情感标签"""
try: try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出 prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好,不要输出其他内容: 只输出标签就好,不要输出其他内容:
内容:{content} 内容:{content}
输出: 输出:
''' """
content, _ = await self.model_v25.generate_response(prompt) content, _ = await self.model_v25.generate_response(prompt)
content=content.strip() content = content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']: if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content] return [content]
else: else:
return ["neutral"] return ["neutral"]

View File

@@ -5,7 +5,6 @@ from typing import Dict, ForwardRef, List, Optional, Union
import urllib3 import urllib3
from loguru import logger from loguru import logger
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_image import image_manager from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream from .chat_stream import ChatStream
@@ -110,23 +109,30 @@ class MessageRecv(MessageBase):
return f"[{time_str}] {name}: {self.processed_plain_text}\n" return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass @dataclass
class MessageProcessBase(MessageBase): class Message(MessageBase):
"""消息处理基类,用于处理中和发送中的消息""" chat_stream: ChatStream=None
reply: Optional['Message'] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
time: int,
chat_stream: ChatStream, chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None reply: Optional['MessageRecv'] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
): ):
# 构造基础消息信息 # 构造基础消息信息
message_info = BaseMessageInfo( message_info = BaseMessageInfo(
platform=chat_stream.platform, platform=chat_stream.platform,
message_id=message_id, message_id=message_id,
time=int(time.time()), time=time,
group_info=chat_stream.group_info, group_info=chat_stream.group_info,
user_info=chat_stream.user_info user_info=user_info
) )
# 调用父类初始化 # 调用父类初始化
@@ -136,17 +142,41 @@ class MessageProcessBase(MessageBase):
raw_message=None raw_message=None
) )
# 处理状态相关属性 self.chat_stream = chat_stream
self.thinking_start_time = int(time.time())
self.thinking_time = 0
# 文本处理相关属性 # 文本处理相关属性
self.processed_plain_text = "" self.processed_plain_text = detailed_plain_text
self.detailed_plain_text = "" self.detailed_plain_text = processed_plain_text
# 回复消息 # 回复消息
self.reply = reply self.reply = reply
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
def update_thinking_time(self) -> float: def update_thinking_time(self) -> float:
"""更新思考时间""" """更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2) self.thinking_time = round(time.time() - self.thinking_start_time, 2)
@@ -224,12 +254,14 @@ class MessageThinking(MessageProcessBase):
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional['MessageRecv'] = None reply: Optional['MessageRecv'] = None
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
message_id=message_id, message_id=message_id,
chat_stream=chat_stream, chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段 message_segment=None, # 思考状态不需要消息段
reply=reply reply=reply
) )
@@ -237,15 +269,6 @@ class MessageThinking(MessageProcessBase):
# 思考状态特有属性 # 思考状态特有属性
self.interrupt = False self.interrupt = False
@classmethod
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
"""从聊天流创建思考状态消息"""
return cls(
message_id=message_id,
chat_stream=chat_stream,
reply=reply
)
@dataclass @dataclass
class MessageSending(MessageProcessBase): class MessageSending(MessageProcessBase):
"""发送状态的消息类""" """发送状态的消息类"""
@@ -254,6 +277,7 @@ class MessageSending(MessageProcessBase):
self, self,
message_id: str, message_id: str,
chat_stream: ChatStream, chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Seg, message_segment: Seg,
reply: Optional['MessageRecv'] = None, reply: Optional['MessageRecv'] = None,
is_head: bool = False, is_head: bool = False,
@@ -263,6 +287,7 @@ class MessageSending(MessageProcessBase):
super().__init__( super().__init__(
message_id=message_id, message_id=message_id,
chat_stream=chat_stream, chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment, message_segment=message_segment,
reply=reply reply=reply
) )
@@ -296,11 +321,17 @@ class MessageSending(MessageProcessBase):
message_id=thinking.message_info.message_id, message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream, chat_stream=thinking.chat_stream,
message_segment=message_segment, message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply, reply=thinking.reply,
is_head=is_head, is_head=is_head,
is_emoji=is_emoji is_emoji=is_emoji
) )
def to_dict(self):
ret= super().to_dict()
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass @dataclass
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""

View File

@@ -79,6 +79,21 @@ class GroupInfo:
"""转换为字典格式""" """转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None} return {k: v for k, v in asdict(self).items() if v is not None}
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
group_id=data.get('group_id'),
group_name=data.get('group_name',None)
)
@dataclass @dataclass
class UserInfo: class UserInfo:
"""用户信息类""" """用户信息类"""
@@ -91,6 +106,22 @@ class UserInfo:
"""转换为字典格式""" """转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None} return {k: v for k, v in asdict(self).items() if v is not None}
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
user_id=data.get('user_id'),
user_nickname=data.get('user_nickname',None),
user_cardname=data.get('user_cardname',None)
)
@dataclass @dataclass
class BaseMessageInfo: class BaseMessageInfo:
"""消息信息类""" """消息信息类"""
@@ -147,7 +178,7 @@ class MessageBase:
""" """
message_info = BaseMessageInfo(**data.get('message_info', {})) message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {})) message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message') raw_message = data.get('raw_message',None)
return cls( return cls(
message_info=message_info, message_info=message_info,
message_segment=message_segment, message_segment=message_segment,

View File

@@ -139,27 +139,24 @@ class MessageSendCQ(MessageCQ):
def __init__( def __init__(
self, self,
message_id: int, data: Dict
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply_to_message_id: Optional[int] = None,
platform: str = "qq"
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__(message_id, user_id, group_id, platform) message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info.user_id,
message_info.group_info.group_id if message_info.group_info else None,
message_info.platform)
self.message_segment = message_segment self.message_segment = message_segment
self.raw_message = self._generate_raw_message(reply_to_message_id) self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, reply_to_message_id: Optional[int] = None) -> str: def _generate_raw_message(self, ) -> str:
"""将Seg对象转换为raw_message""" """将Seg对象转换为raw_message"""
segments = [] segments = []
# 添加回复消息
if reply_to_message_id:
segments.append(cq_code_tool.create_reply_cq(reply_to_message_id))
# 处理消息段 # 处理消息段
if self.message_segment.type == 'seglist': if self.message_segment.type == 'seglist':
for seg in self.message_segment.data: for seg in self.message_segment.data:

View File

@@ -29,13 +29,12 @@ class Message_Sender:
) -> None: ) -> None:
"""发送消息""" """发送消息"""
if isinstance(message, MessageSending): if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send=MessageSendCQ( message_send=MessageSendCQ(
message_id=message.message_id, data=message_json
user_id=message.message_info.user_info.user_id,
message_segment=message.message_segment,
reply=message.reply
) )
if message.message_info.group_info:
if message_send.message_info.group_info:
try: try:
await self._current_bot.send_group_msg( await self._current_bot.send_group_msg(
group_id=message.message_info.group_info.group_id, group_id=message.message_info.group_info.group_id,

View File

@@ -8,6 +8,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from .config import global_config from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import ChatStream, chat_manager
class PromptBuilder: class PromptBuilder:
@@ -22,7 +23,7 @@ class PromptBuilder:
message_txt: str, message_txt: str,
sender_name: str = "某人", sender_name: str = "某人",
relationship_value: float = 0.0, relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]: stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt """构建prompt
Args: Args:
@@ -72,11 +73,17 @@ class PromptBuilder:
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}") print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_in_group=True
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
@@ -112,8 +119,10 @@ class PromptBuilder:
#激活prompt构建 #激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇 #检测机器人相关词汇
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords) is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
@@ -129,16 +138,20 @@ class PromptBuilder:
probability_3 = global_config.PERSONALITY_3 probability_3 = global_config.PERSONALITY_3
prompt_personality = '' prompt_personality = ''
personality_choice = random.random() personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}{prompt_in_group},{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt} 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。''' 请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}{prompt_in_group}{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt} 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。'''
else: # 第三种人格 else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}{prompt_in_group}{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt} 现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
请你表达自己的见解和观点。可以有个性。''' 请你表达自己的见解和观点。可以有个性。'''

View File

@@ -16,17 +16,16 @@ class Impression:
class Relationship: class Relationship:
user_id: int = None user_id: int = None
platform: str = None platform: str = None
platform: str = None
gender: str = None gender: str = None
age: int = None age: int = None
nickname: str = None nickname: str = None
relationship_value: float = None relationship_value: float = None
saved = False saved = False
def __init__(self, chat:ChatStream,data:dict): def __init__(self, chat:ChatStream=None,data:dict=None):
self.user_id=chat.user_info.user_id self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0)
self.platform=chat.platform self.platform=chat.platform if chat.user_info else data.get('platform','')
self.nickname=chat.user_info.user_nickname self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0) self.relationship_value=data.get('relationship_value',0)
self.age=data.get('age',0) self.age=data.get('age',0)
self.gender=data.get('gender','') self.gender=data.get('gender','')
@@ -35,7 +34,6 @@ class Relationship:
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键 self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self, async def update_relationship(self,
chat_stream:ChatStream, chat_stream:ChatStream,
@@ -43,9 +41,7 @@ class RelationshipManager:
**kwargs) -> Optional[Relationship]: **kwargs) -> Optional[Relationship]:
"""更新或创建关系 """更新或创建关系
Args: Args:
user_id: 用户ID可选如果提供user_info则不需要 chat_stream: 聊天流对象
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
data: 字典格式的数据(可选) data: 字典格式的数据(可选)
**kwargs: 其他参数 **kwargs: 其他参数
Returns: Returns:
@@ -66,44 +62,18 @@ class RelationshipManager:
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(key) relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship: if relationship:
# 如果存在,更新现有对象 # 如果存在,更新现有对象
if isinstance(data, dict): if isinstance(data, dict):
for k, value in data.items(): for k, value in data.items():
if hasattr(relationship, k) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value) setattr(relationship, k, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else:
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else: else:
# 如果不存在,创建新对象 # 如果不存在,创建新对象
if user_info is not None: if chat_stream.user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs) relationship = Relationship(chat=chat_stream, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
else: else:
kwargs['platform'] = platform raise ValueError("必须提供user_id或user_info")
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
if user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship self.relationships[key] = relationship
# 保存到数据库 # 保存到数据库
@@ -113,36 +83,7 @@ class RelationshipManager:
return relationship return relationship
async def update_relationship_value(self, async def update_relationship_value(self,
user_id: int = None, chat_stream:ChatStream,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]: **kwargs) -> Optional[Relationship]:
"""更新关系值 """更新关系值
Args: Args:
@@ -154,6 +95,7 @@ class RelationshipManager:
Relationship: 关系对象 Relationship: 关系对象
""" """
# 确定user_id和platform # 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None: if user_info is not None:
user_id = user_info.user_id user_id = user_info.user_id
platform = user_info.platform or 'qq' platform = user_info.platform or 'qq'
@@ -168,10 +110,7 @@ class RelationshipManager:
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(key) relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship: if relationship:
for k, value in kwargs.items():
if k == 'relationship_value':
for k, value in kwargs.items(): for k, value in kwargs.items():
if k == 'relationship_value': if k == 'relationship_value':
relationship.relationship_value += value relationship.relationship_value += value
@@ -181,43 +120,12 @@ class RelationshipManager:
else: else:
# 如果不存在且提供了user_info则创建新的关系 # 如果不存在且提供了user_info则创建新的关系
if user_info is not None: if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs) return await self.update_relationship(chat_stream=chat_stream, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新") print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
return None return None
def get_relationship(self, def get_relationship(self,
user_id: int = None, chat_stream:ChatStream) -> Optional[Relationship]:
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象 """获取用户关系对象
Args: Args:
user_id: 用户ID可选如果提供user_info则不需要 user_id: 用户ID可选如果提供user_info则不需要
@@ -227,6 +135,8 @@ class RelationshipManager:
Relationship: 关系对象 Relationship: 关系对象
""" """
# 确定user_id和platform # 确定user_id和platform
user_info = chat_stream.user_info
platform = chat_stream.user_info.platform or 'qq'
if user_info is not None: if user_info is not None:
user_id = user_info.user_id user_id = user_info.user_id
platform = user_info.platform or 'qq' platform = user_info.platform or 'qq'
@@ -248,18 +158,10 @@ class RelationshipManager:
if 'platform' not in data: if 'platform' not in data:
data['platform'] = 'qq' data['platform'] = 'qq'
rela = Relationship(data=data)
"""从数据库加载或创建新的关系对象"""
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data) rela = Relationship(data=data)
rela.saved = True rela.saved = True
key = (rela.user_id, rela.platform) key = (rela.user_id, rela.platform)
self.relationships[key] = rela self.relationships[key] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela return rela
async def load_all_relationships(self): async def load_all_relationships(self):
@@ -277,7 +179,6 @@ class RelationshipManager:
# 依次加载每条记录 # 依次加载每条记录
for data in all_relationships: for data in all_relationships:
await self.load_relationship(data) await self.load_relationship(data)
await self.load_relationship(data)
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
while True: while True:
@@ -288,19 +189,15 @@ class RelationshipManager:
async def _save_all_relationships(self): async def _save_all_relationships(self):
"""将所有关系数据保存到数据库""" """将所有关系数据保存到数据库"""
# 保存所有关系数据 # 保存所有关系数据
for (userid, platform), relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items(): for (userid, platform), relationship in self.relationships.items():
if not relationship.saved: if not relationship.saved:
relationship.saved = True relationship.saved = True
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
async def storage_relationship(self, relationship: Relationship): async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中""" """将关系记录存储到数据库中"""
user_id = relationship.user_id user_id = relationship.user_id
platform = relationship.platform platform = relationship.platform
platform = relationship.platform
nickname = relationship.nickname nickname = relationship.nickname
relationship_value = relationship.relationship_value relationship_value = relationship.relationship_value
gender = relationship.gender gender = relationship.gender
@@ -309,10 +206,8 @@ class RelationshipManager:
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'user_id': user_id, 'platform': platform}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {
'platform': platform,
'platform': platform, 'platform': platform,
'nickname': nickname, 'nickname': nickname,
'relationship_value': relationship_value, 'relationship_value': relationship_value,
@@ -323,27 +218,6 @@ class RelationshipManager:
upsert=True upsert=True
) )
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
def get_name(self, def get_name(self,
user_id: int = None, user_id: int = None,
@@ -370,11 +244,6 @@ class RelationshipManager:
# 确保user_id是整数类型 # 确保user_id是整数类型
user_id = int(user_id) user_id = int(user_id)
key = (user_id, platform) key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
key = (user_id, platform)
if key in self.relationships: if key in self.relationships:
return self.relationships[key].nickname return self.relationships[key].nickname
elif user_info is not None: elif user_info is not None:

View File

@@ -18,8 +18,9 @@ class MessageStorage:
"time": message.message_info.time, "time": message.message_info.time,
"chat_id":chat_stream.stream_id, "chat_id":chat_stream.stream_id,
"chat_info": chat_stream.to_dict(), "chat_info": chat_stream.to_dict(),
"detailed_plain_text": message.detailed_plain_text, "user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
"detailed_plain_text": message.detailed_plain_text,
"topic": topic, "topic": topic,
} }
self.db.db.messages.insert_one(message_data) self.db.db.messages.insert_one(message_data)

View File

@@ -11,7 +11,9 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config from .config import global_config
from .message_cq import Message from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
from .chat_stream import ChatStream
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -32,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
return result return result
def is_mentioned_bot_in_message(message: Message) -> bool: def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人""" """检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME] keywords = [global_config.BOT_NICKNAME]
for keyword in keywords: for keyword in keywords:
@@ -41,15 +43,6 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
return False return False
def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME]
for keyword in keywords:
if keyword in message:
return True
return False
async def get_embedding(text): async def get_embedding(text):
"""获取文本的embedding向量""" """获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding) llm = LLM_request(model=global_config.embedding)
@@ -84,10 +77,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
if closest_record and closest_record.get('memorized', 0) < 4: if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time'] closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid chat_id = closest_record['chat_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find( chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id} {"time": {"$gt": closest_time}, "chat_id": chat_id}
).sort('time', 1).limit(length)) ).sort('time', 1).limit(length))
# 更新每条消息的memorized属性 # 更新每条消息的memorized属性
@@ -111,7 +104,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return '' return ''
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录
Args: Args:
@@ -125,35 +118,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息 # 从数据库获取最近消息
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"chat_id": chat_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit)) ).sort("time", -1).limit(limit))
if not recent_messages: if not recent_messages:
return [] return []
# 转换为 Message对象列表 # 转换为 Message对象列表
from .message_cq import Message
message_objects = [] message_objects = []
for msg_data in recent_messages: for msg_data in recent_messages:
try: try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message( msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"], message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"], chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""), processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id detailed_plain_text=msg_data.get("detailed_plain_text", "")
) )
await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
print("[WARNING] 数据库中存在无效的消息") print("[WARNING] 数据库中存在无效的消息")
@@ -164,13 +150,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False): def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"chat_id": chat_stream_id},
{ {
"time": 1, # 返回时间字段 "time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段 "chat_id":1,
"user_nickname": 1, # 返回用户昵称字段 "chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段 "message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段 "detailed_plain_text": 1 # 返回处理后的文本字段
} }