feat: 重构完成开始测试debug
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from random import random
|
from random import random
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
|
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
|
||||||
|
|
||||||
@@ -11,25 +10,18 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
|||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
|
||||||
from .message_cq import (
|
from .message_cq import (
|
||||||
MessageRecvCQ,
|
MessageRecvCQ,
|
||||||
MessageSendCQ,
|
|
||||||
)
|
|
||||||
from .chat_stream import chat_manager
|
|
||||||
MessageRecvCQ,
|
|
||||||
MessageSendCQ,
|
|
||||||
)
|
)
|
||||||
from .chat_stream import chat_manager
|
from .chat_stream import chat_manager
|
||||||
|
|
||||||
from .message_sender import message_manager # 导入新的消息管理器
|
from .message_sender import message_manager # 导入新的消息管理器
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||||
from .utils_image import image_path_to_base64
|
from .utils_image import image_path_to_base64
|
||||||
from .utils_image import image_path_to_base64
|
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .message_base import UserInfo, GroupInfo, Seg
|
from .message_base import UserInfo, GroupInfo, Seg
|
||||||
from .message_base import UserInfo, GroupInfo, Seg
|
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -53,24 +45,21 @@ class ChatBot:
|
|||||||
|
|
||||||
self.bot = bot # 更新 bot 实例
|
self.bot = bot # 更新 bot 实例
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||||
|
|
||||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
# 白名单设定由nontbot侧完成
|
||||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
if event.group_id:
|
||||||
|
if event.group_id not in global_config.talk_allowed_groups:
|
||||||
|
return
|
||||||
|
if event.user_id in global_config.ban_user_id:
|
||||||
|
return
|
||||||
|
|
||||||
message_cq=MessageRecvCQ(
|
|
||||||
message_cq=MessageRecvCQ(
|
message_cq=MessageRecvCQ(
|
||||||
message_id=event.message_id,
|
message_id=event.message_id,
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
raw_message=str(event.original_message),
|
raw_message=str(event.original_message),
|
||||||
group_id=event.group_id,
|
group_id=event.group_id,
|
||||||
user_id=event.user_id,
|
|
||||||
raw_message=str(event.original_message),
|
|
||||||
group_id=event.group_id,
|
|
||||||
reply_message=event.reply,
|
reply_message=event.reply,
|
||||||
platform='qq'
|
platform='qq'
|
||||||
)
|
)
|
||||||
@@ -78,37 +67,26 @@ class ChatBot:
|
|||||||
|
|
||||||
# 进入maimbot
|
# 进入maimbot
|
||||||
message=MessageRecv(**message_json)
|
message=MessageRecv(**message_json)
|
||||||
await message.process()
|
|
||||||
groupinfo=message.message_info.group_info
|
groupinfo=message.message_info.group_info
|
||||||
userinfo=message.message_info.user_info
|
userinfo=message.message_info.user_info
|
||||||
messageinfo=message.message_info
|
messageinfo=message.message_info
|
||||||
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
|
||||||
|
|
||||||
# 消息过滤,涉及到config有待更新
|
# 消息过滤,涉及到config有待更新
|
||||||
if groupinfo:
|
|
||||||
if groupinfo.group_id not in global_config.talk_allowed_groups:
|
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
||||||
return
|
await relationship_manager.update_relationship(chat_stream=chat,)
|
||||||
else:
|
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
|
||||||
if userinfo:
|
|
||||||
if userinfo.user_id in []:
|
await message.process()
|
||||||
pass
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
return
|
|
||||||
if userinfo.user_id in global_config.ban_user_id:
|
|
||||||
return
|
|
||||||
# 过滤词
|
# 过滤词
|
||||||
for word in global_config.ban_words:
|
for word in global_config.ban_words:
|
||||||
if word in message.processed_plain_text:
|
|
||||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
|
||||||
if word in message.processed_plain_text:
|
if word in message.processed_plain_text:
|
||||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||||
return
|
return
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -130,20 +108,13 @@ class ChatBot:
|
|||||||
is_emoji=message.is_emoji,
|
is_emoji=message.is_emoji,
|
||||||
interested_rate=interested_rate
|
interested_rate=interested_rate
|
||||||
)
|
)
|
||||||
current_willing = willing_manager.get_willing(
|
current_willing = willing_manager.get_willing(chat_stream=chat)
|
||||||
chat_stream=chat
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
||||||
|
|
||||||
response = None
|
response = None
|
||||||
|
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
bot_user_info=UserInfo(
|
|
||||||
user_id=global_config.BOT_QQ,
|
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
|
||||||
platform=messageinfo.platform
|
|
||||||
)
|
|
||||||
bot_user_info=UserInfo(
|
bot_user_info=UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
@@ -151,22 +122,16 @@ class ChatBot:
|
|||||||
)
|
)
|
||||||
tinking_time_point = round(time.time(), 2)
|
tinking_time_point = round(time.time(), 2)
|
||||||
think_id = 'mt' + str(tinking_time_point)
|
think_id = 'mt' + str(tinking_time_point)
|
||||||
thinking_message = MessageThinking.from_chat_stream(
|
thinking_message = MessageThinking(
|
||||||
chat_stream=chat,
|
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
reply=message
|
|
||||||
)
|
|
||||||
thinking_message = MessageThinking.from_chat_stream(
|
|
||||||
chat_stream=chat,
|
chat_stream=chat,
|
||||||
message_id=think_id,
|
bot_user_info=bot_user_info,
|
||||||
reply=message
|
reply=message
|
||||||
)
|
)
|
||||||
|
|
||||||
message_manager.add_message(thinking_message)
|
message_manager.add_message(thinking_message)
|
||||||
|
|
||||||
willing_manager.change_reply_willing_sent(
|
willing_manager.change_reply_willing_sent(chat)
|
||||||
chat_stream=chat
|
|
||||||
)
|
|
||||||
|
|
||||||
response,raw_content = await self.gpt.generate_response(message)
|
response,raw_content = await self.gpt.generate_response(message)
|
||||||
|
|
||||||
@@ -201,18 +166,11 @@ class ChatBot:
|
|||||||
accu_typing_time += typing_time
|
accu_typing_time += typing_time
|
||||||
timepoint = tinking_time_point + accu_typing_time
|
timepoint = tinking_time_point + accu_typing_time
|
||||||
|
|
||||||
message_segment = Seg(type='text', data=msg)
|
|
||||||
bot_message = MessageSending(
|
|
||||||
message_segment = Seg(type='text', data=msg)
|
message_segment = Seg(type='text', data=msg)
|
||||||
bot_message = MessageSending(
|
bot_message = MessageSending(
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
chat_stream=chat,
|
chat_stream=chat,
|
||||||
message_segment=message_segment,
|
bot_user_info=bot_user_info,
|
||||||
reply=message,
|
|
||||||
is_head=not mark_head,
|
|
||||||
is_emoji=False
|
|
||||||
)
|
|
||||||
chat_stream=chat,
|
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
reply=message,
|
reply=message,
|
||||||
is_head=not mark_head,
|
is_head=not mark_head,
|
||||||
@@ -235,7 +193,6 @@ class ChatBot:
|
|||||||
if emoji_raw != None:
|
if emoji_raw != None:
|
||||||
emoji_path,discription = emoji_raw
|
emoji_path,discription = emoji_raw
|
||||||
|
|
||||||
emoji_cq = image_path_to_base64(emoji_path)
|
|
||||||
emoji_cq = image_path_to_base64(emoji_path)
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
if random() < 0.5:
|
if random() < 0.5:
|
||||||
@@ -247,15 +204,7 @@ class ChatBot:
|
|||||||
bot_message = MessageSending(
|
bot_message = MessageSending(
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
chat_stream=chat,
|
chat_stream=chat,
|
||||||
message_segment=message_segment,
|
bot_user_info=bot_user_info,
|
||||||
reply=message,
|
|
||||||
is_head=False,
|
|
||||||
is_emoji=True
|
|
||||||
)
|
|
||||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
|
||||||
bot_message = MessageSending(
|
|
||||||
message_id=think_id,
|
|
||||||
chat_stream=chat,
|
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
reply=message,
|
reply=message,
|
||||||
is_head=False,
|
is_head=False,
|
||||||
@@ -273,20 +222,12 @@ class ChatBot:
|
|||||||
'fearful': -0.7,
|
'fearful': -0.7,
|
||||||
'neutral': 0.1
|
'neutral': 0.1
|
||||||
}
|
}
|
||||||
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=valuedict[emotion[0]])
|
||||||
# 使用情绪管理器更新情绪
|
# 使用情绪管理器更新情绪
|
||||||
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
||||||
|
|
||||||
willing_manager.change_reply_willing_after_sent(
|
willing_manager.change_reply_willing_after_sent(
|
||||||
platform=messageinfo.platform,
|
chat_stream=chat
|
||||||
user_info=userinfo,
|
|
||||||
group_info=groupinfo
|
|
||||||
)
|
|
||||||
|
|
||||||
willing_manager.change_reply_willing_after_sent(
|
|
||||||
platform=messageinfo.platform,
|
|
||||||
user_info=userinfo,
|
|
||||||
group_info=groupinfo
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建全局ChatBot实例
|
# 创建全局ChatBot实例
|
||||||
|
|||||||
@@ -1,53 +1,65 @@
|
|||||||
import time
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Dict, Tuple
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message_base import UserInfo, GroupInfo
|
from .message_base import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
|
||||||
class ChatStream:
|
class ChatStream:
|
||||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
def __init__(self,
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
platform: str,
|
platform: str,
|
||||||
user_info: UserInfo,
|
user_info: UserInfo,
|
||||||
group_info: Optional[GroupInfo] = None,
|
group_info: Optional[GroupInfo] = None,
|
||||||
data: dict = None):
|
data: dict = None,
|
||||||
|
):
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.user_info = user_info
|
self.user_info = user_info
|
||||||
self.group_info = group_info
|
self.group_info = group_info
|
||||||
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
|
self.create_time = (
|
||||||
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
|
data.get("create_time", int(time.time())) if data else int(time.time())
|
||||||
|
)
|
||||||
|
self.last_active_time = (
|
||||||
|
data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
|
)
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
result = {
|
result = {
|
||||||
'stream_id': self.stream_id,
|
"stream_id": self.stream_id,
|
||||||
'platform': self.platform,
|
"platform": self.platform,
|
||||||
'user_info': self.user_info.to_dict() if self.user_info else None,
|
"user_info": self.user_info.to_dict() if self.user_info else None,
|
||||||
'group_info': self.group_info.to_dict() if self.group_info else None,
|
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||||
'create_time': self.create_time,
|
"create_time": self.create_time,
|
||||||
'last_active_time': self.last_active_time
|
"last_active_time": self.last_active_time,
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> 'ChatStream':
|
def from_dict(cls, data: dict) -> "ChatStream":
|
||||||
"""从字典创建实例"""
|
"""从字典创建实例"""
|
||||||
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
|
user_info = (
|
||||||
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
|
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
|
||||||
|
)
|
||||||
|
group_info = (
|
||||||
|
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
stream_id=data['stream_id'],
|
stream_id=data["stream_id"],
|
||||||
platform=data['platform'],
|
platform=data["platform"],
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
data=data
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_active_time(self):
|
def update_active_time(self):
|
||||||
@@ -58,6 +70,7 @@ class ChatStream:
|
|||||||
|
|
||||||
class ChatManager:
|
class ChatManager:
|
||||||
"""聊天管理器,管理所有聊天流"""
|
"""聊天管理器,管理所有聊天流"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
|
||||||
@@ -97,33 +110,32 @@ class ChatManager:
|
|||||||
|
|
||||||
def _ensure_collection(self):
|
def _ensure_collection(self):
|
||||||
"""确保数据库集合存在并创建索引"""
|
"""确保数据库集合存在并创建索引"""
|
||||||
if 'chat_streams' not in self.db.db.list_collection_names():
|
if "chat_streams" not in self.db.db.list_collection_names():
|
||||||
self.db.db.create_collection('chat_streams')
|
self.db.db.create_collection("chat_streams")
|
||||||
# 创建索引
|
# 创建索引
|
||||||
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
|
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||||
self.db.db.chat_streams.create_index([
|
self.db.db.chat_streams.create_index(
|
||||||
('platform', 1),
|
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
|
||||||
('user_info.user_id', 1),
|
)
|
||||||
('group_info.group_id', 1)
|
|
||||||
])
|
|
||||||
|
|
||||||
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
def _generate_stream_id(
|
||||||
|
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||||
|
) -> str:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
# 组合关键信息
|
# 组合关键信息
|
||||||
components = [
|
components = [
|
||||||
platform,
|
platform,
|
||||||
str(user_info.user_id),
|
str(user_info.user_id),
|
||||||
str(group_info.group_id) if group_info else 'private'
|
str(group_info.group_id) if group_info else "private",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 使用MD5生成唯一ID
|
# 使用MD5生成唯一ID
|
||||||
key = '_'.join(components)
|
key = "_".join(components)
|
||||||
return hashlib.md5(key.encode()).hexdigest()
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
async def get_or_create_stream(self,
|
async def get_or_create_stream(
|
||||||
platform: str,
|
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||||
user_info: UserInfo,
|
) -> ChatStream:
|
||||||
group_info: Optional[GroupInfo] = None) -> ChatStream:
|
|
||||||
"""获取或创建聊天流
|
"""获取或创建聊天流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -148,7 +160,7 @@ class ChatManager:
|
|||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
|
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
|
||||||
if data:
|
if data:
|
||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
@@ -162,7 +174,7 @@ class ChatManager:
|
|||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
group_info=group_info
|
group_info=group_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存到内存和数据库
|
# 保存到内存和数据库
|
||||||
@@ -174,10 +186,9 @@ class ChatManager:
|
|||||||
"""通过stream_id获取聊天流"""
|
"""通过stream_id获取聊天流"""
|
||||||
return self.streams.get(stream_id)
|
return self.streams.get(stream_id)
|
||||||
|
|
||||||
def get_stream_by_info(self,
|
def get_stream_by_info(
|
||||||
platform: str,
|
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||||
user_info: UserInfo,
|
) -> Optional[ChatStream]:
|
||||||
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
|
|
||||||
"""通过信息获取聊天流"""
|
"""通过信息获取聊天流"""
|
||||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||||
return self.streams.get(stream_id)
|
return self.streams.get(stream_id)
|
||||||
@@ -186,9 +197,7 @@ class ChatManager:
|
|||||||
"""保存聊天流到数据库"""
|
"""保存聊天流到数据库"""
|
||||||
if not stream.saved:
|
if not stream.saved:
|
||||||
self.db.db.chat_streams.update_one(
|
self.db.db.chat_streams.update_one(
|
||||||
{'stream_id': stream.stream_id},
|
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
|
||||||
{'$set': stream.to_dict()},
|
|
||||||
upsert=True
|
|
||||||
)
|
)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
|
|
||||||
|
|||||||
@@ -3,23 +3,22 @@ import html
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional, List, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
# 解析各种CQ码
|
# 解析各种CQ码
|
||||||
# 包含CQ码类
|
# 包含CQ码类
|
||||||
import urllib3
|
import urllib3
|
||||||
|
from loguru import logger
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from urllib3.util import create_urllib3_context
|
from urllib3.util import create_urllib3_context
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .mapper import emojimapper
|
from .mapper import emojimapper
|
||||||
from .utils_image import image_manager
|
|
||||||
from .utils_user import get_user_nickname
|
|
||||||
from .message_base import Seg
|
from .message_base import Seg
|
||||||
|
from .utils_user import get_user_nickname
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -37,8 +36,11 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
|
|||||||
|
|
||||||
def init_poolmanager(self, connections, maxsize, block=False):
|
def init_poolmanager(self, connections, maxsize, block=False):
|
||||||
self.poolmanager = urllib3.poolmanager.PoolManager(
|
self.poolmanager = urllib3.poolmanager.PoolManager(
|
||||||
num_pools=connections, maxsize=maxsize,
|
num_pools=connections,
|
||||||
block=block, ssl_context=self.ssl_context)
|
maxsize=maxsize,
|
||||||
|
block=block,
|
||||||
|
ssl_context=self.ssl_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -52,6 +54,7 @@ class CQCode:
|
|||||||
raw_code: 原始CQ码字符串
|
raw_code: 原始CQ码字符串
|
||||||
translated_segments: 经过处理后的Seg对象列表
|
translated_segments: 经过处理后的Seg对象列表
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
params: Dict[str, str]
|
params: Dict[str, str]
|
||||||
group_id: int
|
group_id: int
|
||||||
@@ -65,77 +68,52 @@ class CQCode:
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""初始化LLM实例"""
|
"""初始化LLM实例"""
|
||||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
self._llm = LLM_request(
|
||||||
|
model=global_config.vlm, temperature=0.4, max_tokens=300
|
||||||
|
)
|
||||||
|
|
||||||
def translate(self):
|
def translate(self):
|
||||||
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
||||||
if self.type == 'text':
|
if self.type == "text":
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(
|
||||||
type='text',
|
type="text", data=self.params.get("text", "")
|
||||||
data=self.params.get('text', '')
|
|
||||||
)
|
)
|
||||||
elif self.type == 'image':
|
elif self.type == "image":
|
||||||
base64_data = self.translate_image()
|
base64_data = self.translate_image()
|
||||||
if base64_data:
|
if base64_data:
|
||||||
if self.params.get('sub_type') == '0':
|
if self.params.get("sub_type") == "0":
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="image", data=base64_data)
|
||||||
type='image',
|
|
||||||
data=base64_data
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="emoji", data=base64_data)
|
||||||
type='emoji',
|
|
||||||
data=base64_data
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
|
self.translated_segments = Seg(type="text", data="[图片]")
|
||||||
|
elif self.type == "at":
|
||||||
|
user_nickname = get_user_nickname(self.params.get("qq", ""))
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(
|
||||||
type='text',
|
type="text", data=f"[@{user_nickname or '某人'}]"
|
||||||
data='[图片]'
|
|
||||||
)
|
)
|
||||||
elif self.type == 'at':
|
elif self.type == "reply":
|
||||||
user_nickname = get_user_nickname(self.params.get('qq', ''))
|
|
||||||
self.translated_segments = Seg(
|
|
||||||
type='text',
|
|
||||||
data=f"[@{user_nickname or '某人'}]"
|
|
||||||
)
|
|
||||||
elif self.type == 'reply':
|
|
||||||
reply_segments = self.translate_reply()
|
reply_segments = self.translate_reply()
|
||||||
if reply_segments:
|
if reply_segments:
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="seglist", data=reply_segments)
|
||||||
type='seglist',
|
|
||||||
data=reply_segments
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
|
self.translated_segments = Seg(type="text", data="[回复某人消息]")
|
||||||
|
elif self.type == "face":
|
||||||
|
face_id = self.params.get("id", "")
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(
|
||||||
type='text',
|
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
|
||||||
data='[回复某人消息]'
|
|
||||||
)
|
)
|
||||||
elif self.type == 'face':
|
elif self.type == "forward":
|
||||||
face_id = self.params.get('id', '')
|
|
||||||
self.translated_segments = Seg(
|
|
||||||
type='text',
|
|
||||||
data=f"[{emojimapper.get(int(face_id), '表情')}]"
|
|
||||||
)
|
|
||||||
elif self.type == 'forward':
|
|
||||||
forward_segments = self.translate_forward()
|
forward_segments = self.translate_forward()
|
||||||
if forward_segments:
|
if forward_segments:
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="seglist", data=forward_segments)
|
||||||
type='seglist',
|
|
||||||
data=forward_segments
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="text", data="[转发消息]")
|
||||||
type='text',
|
|
||||||
data='[转发消息]'
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.translated_segments = Seg(
|
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
|
||||||
type='text',
|
|
||||||
data=f"[{self.type}]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_img(self):
|
def get_img(self):
|
||||||
'''
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
|
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
|
||||||
'Accept': 'image/*;q=0.8',
|
'Accept': 'image/*;q=0.8',
|
||||||
@@ -144,18 +122,18 @@ class CQCode:
|
|||||||
'Cache-Control': 'no-cache',
|
'Cache-Control': 'no-cache',
|
||||||
'Pragma': 'no-cache'
|
'Pragma': 'no-cache'
|
||||||
}
|
}
|
||||||
'''
|
"""
|
||||||
# 腾讯专用请求头配置
|
# 腾讯专用请求头配置
|
||||||
headers = {
|
headers = {
|
||||||
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
|
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
|
||||||
'Accept': 'text/html, application/xhtml xml, */*',
|
"Accept": "text/html, application/xhtml xml, */*",
|
||||||
'Accept-Encoding': 'gbk, GB2312',
|
"Accept-Encoding": "gbk, GB2312",
|
||||||
'Accept-Language': 'zh-cn',
|
"Accept-Language": "zh-cn",
|
||||||
'Content-Type': 'application/x-www-form-urlencoded',
|
"Content-Type": "application/x-www-form-urlencoded",
|
||||||
'Cache-Control': 'no-cache'
|
"Cache-Control": "no-cache",
|
||||||
}
|
}
|
||||||
url = html.unescape(self.params['url'])
|
url = html.unescape(self.params["url"])
|
||||||
if not url.startswith(('http://', 'https://')):
|
if not url.startswith(("http://", "https://")):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 创建专用会话
|
# 创建专用会话
|
||||||
@@ -171,30 +149,30 @@ class CQCode:
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=15,
|
timeout=15,
|
||||||
allow_redirects=True,
|
allow_redirects=True,
|
||||||
stream=True # 流式传输避免大内存问题
|
stream=True, # 流式传输避免大内存问题
|
||||||
)
|
)
|
||||||
|
|
||||||
# 腾讯服务器特殊状态码处理
|
# 腾讯服务器特殊状态码处理
|
||||||
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
|
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
|
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
|
||||||
|
|
||||||
# 验证内容类型
|
# 验证内容类型
|
||||||
content_type = response.headers.get('Content-Type', '')
|
content_type = response.headers.get("Content-Type", "")
|
||||||
if not content_type.startswith('image/'):
|
if not content_type.startswith("image/"):
|
||||||
raise ValueError(f"非图片内容类型: {content_type}")
|
raise ValueError(f"非图片内容类型: {content_type}")
|
||||||
|
|
||||||
# 转换为Base64
|
# 转换为Base64
|
||||||
image_base64 = base64.b64encode(response.content).decode('utf-8')
|
image_base64 = base64.b64encode(response.content).decode("utf-8")
|
||||||
self.image_base64 = image_base64
|
self.image_base64 = image_base64
|
||||||
return image_base64
|
return image_base64
|
||||||
|
|
||||||
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
||||||
if retry == max_retries - 1:
|
if retry == max_retries - 1:
|
||||||
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
|
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
|
||||||
time.sleep(1.5 ** retry) # 指数退避
|
time.sleep(1.5**retry) # 指数退避
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
|
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
|
||||||
@@ -202,21 +180,21 @@ class CQCode:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def translate_image(self) -> Optional[str]:
|
def translate_image(self) -> Optional[str]:
|
||||||
"""处理图片类型的CQ码,返回base64字符串"""
|
"""处理图片类型的CQ码,返回base64字符串"""
|
||||||
if 'url' not in self.params:
|
if "url" not in self.params:
|
||||||
return None
|
return None
|
||||||
return self.get_img()
|
return self.get_img()
|
||||||
|
|
||||||
def translate_forward(self) -> Optional[List[Seg]]:
|
def translate_forward(self) -> Optional[List[Seg]]:
|
||||||
"""处理转发消息,返回Seg列表"""
|
"""处理转发消息,返回Seg列表"""
|
||||||
try:
|
try:
|
||||||
if 'content' not in self.params:
|
if "content" not in self.params:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
content = self.unescape(self.params['content'])
|
content = self.unescape(self.params["content"])
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
try:
|
try:
|
||||||
messages = ast.literal_eval(content)
|
messages = ast.literal_eval(content)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
@@ -225,46 +203,52 @@ class CQCode:
|
|||||||
|
|
||||||
formatted_segments = []
|
formatted_segments = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
sender = msg.get('sender', {})
|
sender = msg.get("sender", {})
|
||||||
nickname = sender.get('card') or sender.get('nickname', '未知用户')
|
nickname = sender.get("card") or sender.get("nickname", "未知用户")
|
||||||
raw_message = msg.get('raw_message', '')
|
raw_message = msg.get("raw_message", "")
|
||||||
message_array = msg.get('message', [])
|
message_array = msg.get("message", [])
|
||||||
|
|
||||||
if message_array and isinstance(message_array, list):
|
if message_array and isinstance(message_array, list):
|
||||||
for message_part in message_array:
|
for message_part in message_array:
|
||||||
if message_part.get('type') == 'forward':
|
if message_part.get("type") == "forward":
|
||||||
content_seg = Seg(type='text', data='[转发消息]')
|
content_seg = Seg(type="text", data="[转发消息]")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if raw_message:
|
if raw_message:
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
|
|
||||||
message_obj = MessageRecvCQ(
|
message_obj = MessageRecvCQ(
|
||||||
user_id=msg.get('user_id', 0),
|
user_id=msg.get("user_id", 0),
|
||||||
message_id=msg.get('message_id', 0),
|
message_id=msg.get("message_id", 0),
|
||||||
raw_message=raw_message,
|
raw_message=raw_message,
|
||||||
plain_text=raw_message,
|
plain_text=raw_message,
|
||||||
group_id=msg.get('group_id', 0)
|
group_id=msg.get("group_id", 0),
|
||||||
|
)
|
||||||
|
content_seg = Seg(
|
||||||
|
type="seglist", data=message_obj.message_segments
|
||||||
)
|
)
|
||||||
content_seg = Seg(type='seglist', data=message_obj.message_segments)
|
|
||||||
else:
|
else:
|
||||||
content_seg = Seg(type='text', data='[空消息]')
|
content_seg = Seg(type="text", data="[空消息]")
|
||||||
else:
|
else:
|
||||||
if raw_message:
|
if raw_message:
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
|
|
||||||
message_obj = MessageRecvCQ(
|
message_obj = MessageRecvCQ(
|
||||||
user_id=msg.get('user_id', 0),
|
user_id=msg.get("user_id", 0),
|
||||||
message_id=msg.get('message_id', 0),
|
message_id=msg.get("message_id", 0),
|
||||||
raw_message=raw_message,
|
raw_message=raw_message,
|
||||||
plain_text=raw_message,
|
plain_text=raw_message,
|
||||||
group_id=msg.get('group_id', 0)
|
group_id=msg.get("group_id", 0),
|
||||||
|
)
|
||||||
|
content_seg = Seg(
|
||||||
|
type="seglist", data=message_obj.message_segments
|
||||||
)
|
)
|
||||||
content_seg = Seg(type='seglist', data=message_obj.message_segments)
|
|
||||||
else:
|
else:
|
||||||
content_seg = Seg(type='text', data='[空消息]')
|
content_seg = Seg(type="text", data="[空消息]")
|
||||||
|
|
||||||
formatted_segments.append(Seg(type='text', data=f"{nickname}: "))
|
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
|
||||||
formatted_segments.append(content_seg)
|
formatted_segments.append(content_seg)
|
||||||
formatted_segments.append(Seg(type='text', data='\n'))
|
formatted_segments.append(Seg(type="text", data="\n"))
|
||||||
|
|
||||||
return formatted_segments
|
return formatted_segments
|
||||||
|
|
||||||
@@ -275,6 +259,7 @@ class CQCode:
|
|||||||
def translate_reply(self) -> Optional[List[Seg]]:
|
def translate_reply(self) -> Optional[List[Seg]]:
|
||||||
"""处理回复类型的CQ码,返回Seg列表"""
|
"""处理回复类型的CQ码,返回Seg列表"""
|
||||||
from .message_cq import MessageRecvCQ
|
from .message_cq import MessageRecvCQ
|
||||||
|
|
||||||
if self.reply_message is None:
|
if self.reply_message is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -283,17 +268,26 @@ class CQCode:
|
|||||||
user_id=self.reply_message.sender.user_id,
|
user_id=self.reply_message.sender.user_id,
|
||||||
message_id=self.reply_message.message_id,
|
message_id=self.reply_message.message_id,
|
||||||
raw_message=str(self.reply_message.message),
|
raw_message=str(self.reply_message.message),
|
||||||
group_id=self.group_id
|
group_id=self.group_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
segments = []
|
segments = []
|
||||||
if message_obj.user_id == global_config.BOT_QQ:
|
if message_obj.user_id == global_config.BOT_QQ:
|
||||||
segments.append(Seg(type='text', data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
|
segments.append(
|
||||||
|
Seg(
|
||||||
|
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
segments.append(Seg(type='text', data=f"[回复 {self.reply_message.sender.nickname} 的消息: "))
|
segments.append(
|
||||||
|
Seg(
|
||||||
|
type="text",
|
||||||
|
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
segments.append(Seg(type='seglist', data=message_obj.message_segments))
|
segments.append(Seg(type="seglist", data=message_obj.message_segments))
|
||||||
segments.append(Seg(type='text', data="]"))
|
segments.append(Seg(type="text", data="]"))
|
||||||
return segments
|
return segments
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@@ -301,12 +295,12 @@ class CQCode:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def unescape(text: str) -> str:
|
def unescape(text: str) -> str:
|
||||||
"""反转义CQ码中的特殊字符"""
|
"""反转义CQ码中的特殊字符"""
|
||||||
return text.replace(',', ',') \
|
return (
|
||||||
.replace('[', '[') \
|
text.replace(",", ",")
|
||||||
.replace(']', ']') \
|
.replace("[", "[")
|
||||||
.replace('&', '&')
|
.replace("]", "]")
|
||||||
|
.replace("&", "&")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CQCode_tool:
|
class CQCode_tool:
|
||||||
@@ -324,19 +318,15 @@ class CQCode_tool:
|
|||||||
"""
|
"""
|
||||||
# 处理字典形式的CQ码
|
# 处理字典形式的CQ码
|
||||||
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
|
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
|
||||||
cq_type = cq_code.get('type', 'text')
|
cq_type = cq_code.get("type", "text")
|
||||||
params = {}
|
params = {}
|
||||||
if cq_type == 'text':
|
if cq_type == "text":
|
||||||
params['text'] = cq_code.get('data', {}).get('text', '')
|
params["text"] = cq_code.get("data", {}).get("text", "")
|
||||||
else:
|
else:
|
||||||
params = cq_code.get('data', {})
|
params = cq_code.get("data", {})
|
||||||
|
|
||||||
instance = CQCode(
|
instance = CQCode(
|
||||||
type=cq_type,
|
type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply
|
||||||
params=params,
|
|
||||||
group_id=0,
|
|
||||||
user_id=0,
|
|
||||||
reply_message=reply
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 进行翻译处理
|
# 进行翻译处理
|
||||||
@@ -366,10 +356,12 @@ class CQCode_tool:
|
|||||||
# 确保使用绝对路径
|
# 确保使用绝对路径
|
||||||
abs_path = os.path.abspath(file_path)
|
abs_path = os.path.abspath(file_path)
|
||||||
# 转义特殊字符
|
# 转义特殊字符
|
||||||
escaped_path = abs_path.replace('&', '&') \
|
escaped_path = (
|
||||||
.replace('[', '[') \
|
abs_path.replace("&", "&")
|
||||||
.replace(']', ']') \
|
.replace("[", "[")
|
||||||
.replace(',', ',')
|
.replace("]", "]")
|
||||||
|
.replace(",", ",")
|
||||||
|
)
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||||
|
|
||||||
@@ -383,15 +375,14 @@ class CQCode_tool:
|
|||||||
表情包CQ码字符串
|
表情包CQ码字符串
|
||||||
"""
|
"""
|
||||||
# 转义base64数据
|
# 转义base64数据
|
||||||
escaped_base64 = base64_data.replace('&', '&') \
|
escaped_base64 = (
|
||||||
.replace('[', '[') \
|
base64_data.replace("&", "&")
|
||||||
.replace(']', ']') \
|
.replace("[", "[")
|
||||||
.replace(',', ',')
|
.replace("]", "]")
|
||||||
|
.replace(",", ",")
|
||||||
|
)
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cq_code_tool = CQCode_tool()
|
cq_code_tool = CQCode_tool()
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
@@ -13,9 +13,8 @@ from nonebot import get_driver
|
|||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ..chat.utils import get_embedding
|
from ..chat.utils import get_embedding
|
||||||
from ..chat.utils_image import image_path_to_base64
|
from ..chat.utils_image import ImageManager, image_path_to_base64
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from ..chat.utils_image import ImageManager
|
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -78,7 +77,6 @@ class EmojiManager:
|
|||||||
if 'emoji' not in self.db.db.list_collection_names():
|
if 'emoji' not in self.db.db.list_collection_names():
|
||||||
self.db.db.create_collection('emoji')
|
self.db.db.create_collection('emoji')
|
||||||
self.db.db.emoji.create_index([('embedding', '2dsphere')])
|
self.db.db.emoji.create_index([('embedding', '2dsphere')])
|
||||||
self.db.db.emoji.create_index([('tags', 1)])
|
|
||||||
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
||||||
|
|
||||||
def record_usage(self, emoji_id: str):
|
def record_usage(self, emoji_id: str):
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from nonebot import get_driver
|
|||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message_cq import Message
|
from .message import MessageRecv, MessageThinking, MessageSending
|
||||||
from .prompt_builder import prompt_builder
|
from .prompt_builder import prompt_builder
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .utils import process_llm_response
|
from .utils import process_llm_response
|
||||||
@@ -18,48 +18,78 @@ config = driver.config
|
|||||||
|
|
||||||
class ResponseGenerator:
|
class ResponseGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
|
self.model_r1 = LLM_request(
|
||||||
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
|
model=global_config.llm_reasoning,
|
||||||
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
|
temperature=0.7,
|
||||||
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
|
max_tokens=1000,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
self.model_v3 = LLM_request(
|
||||||
|
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
|
||||||
|
)
|
||||||
|
self.model_r1_distill = LLM_request(
|
||||||
|
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
|
||||||
|
)
|
||||||
|
self.model_v25 = LLM_request(
|
||||||
|
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
|
||||||
|
)
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self.current_model_type = 'r1' # 默认使用 R1
|
self.current_model_type = "r1" # 默认使用 R1
|
||||||
|
|
||||||
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
async def generate_response(
|
||||||
|
self, message: MessageThinking
|
||||||
|
) -> Optional[Union[str, List[str]]]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
rand = random.random()
|
rand = random.random()
|
||||||
if rand < global_config.MODEL_R1_PROBABILITY:
|
if rand < global_config.MODEL_R1_PROBABILITY:
|
||||||
self.current_model_type = 'r1'
|
self.current_model_type = "r1"
|
||||||
current_model = self.model_r1
|
current_model = self.model_r1
|
||||||
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
|
elif (
|
||||||
self.current_model_type = 'v3'
|
rand
|
||||||
|
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
|
||||||
|
):
|
||||||
|
self.current_model_type = "v3"
|
||||||
current_model = self.model_v3
|
current_model = self.model_v3
|
||||||
else:
|
else:
|
||||||
self.current_model_type = 'r1_distill'
|
self.current_model_type = "r1_distill"
|
||||||
current_model = self.model_r1_distill
|
current_model = self.model_r1_distill
|
||||||
|
|
||||||
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
|
print(
|
||||||
|
f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++"
|
||||||
|
)
|
||||||
|
|
||||||
model_response = await self._generate_response_with_model(message, current_model)
|
model_response = await self._generate_response_with_model(
|
||||||
raw_content=model_response
|
message, current_model
|
||||||
|
)
|
||||||
|
raw_content = model_response
|
||||||
|
|
||||||
if model_response:
|
if model_response:
|
||||||
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
print(f"{global_config.BOT_NICKNAME}的回复是:{model_response}")
|
||||||
model_response = await self._process_response(model_response)
|
model_response = await self._process_response(model_response)
|
||||||
if model_response:
|
if model_response:
|
||||||
|
return model_response, raw_content
|
||||||
|
return None, raw_content
|
||||||
|
|
||||||
return model_response ,raw_content
|
async def _generate_response_with_model(
|
||||||
return None,raw_content
|
self, message: MessageThinking, model: LLM_request
|
||||||
|
) -> Optional[str]:
|
||||||
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
|
||||||
"""使用指定的模型生成回复"""
|
"""使用指定的模型生成回复"""
|
||||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
sender_name = (
|
||||||
if message.user_cardname:
|
message.chat_stream.user_info.user_nickname
|
||||||
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
|
or f"用户{message.chat_stream.user_info.user_id}"
|
||||||
|
)
|
||||||
|
if message.chat_stream.user_info.user_cardname:
|
||||||
|
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
|
||||||
|
|
||||||
# 获取关系值
|
# 获取关系值
|
||||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
relationship_value = (
|
||||||
|
relationship_manager.get_relationship(
|
||||||
|
message.chat_stream
|
||||||
|
).relationship_value
|
||||||
|
if relationship_manager.get_relationship(message.chat_stream)
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
if relationship_value != 0.0:
|
if relationship_value != 0.0:
|
||||||
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||||
pass
|
pass
|
||||||
@@ -69,7 +99,7 @@ class ResponseGenerator:
|
|||||||
message_txt=message.processed_plain_text,
|
message_txt=message.processed_plain_text,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
relationship_value=relationship_value,
|
relationship_value=relationship_value,
|
||||||
group_id=message.group_id
|
stream_id=message.chat_stream.stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 读空气模块 简化逻辑,先停用
|
# 读空气模块 简化逻辑,先停用
|
||||||
@@ -112,34 +142,51 @@ class ResponseGenerator:
|
|||||||
|
|
||||||
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||||
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||||
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
def _save_to_db(
|
||||||
content: str, reasoning_content: str,):
|
self,
|
||||||
|
message: Message,
|
||||||
|
sender_name: str,
|
||||||
|
prompt: str,
|
||||||
|
prompt_check: str,
|
||||||
|
content: str,
|
||||||
|
reasoning_content: str,
|
||||||
|
):
|
||||||
"""保存对话记录到数据库"""
|
"""保存对话记录到数据库"""
|
||||||
self.db.db.reasoning_logs.insert_one({
|
self.db.db.reasoning_logs.insert_one(
|
||||||
'time': time.time(),
|
{
|
||||||
'group_id': message.group_id,
|
"time": time.time(),
|
||||||
'user': sender_name,
|
"group_id": message.group_id,
|
||||||
'message': message.processed_plain_text,
|
"user": sender_name,
|
||||||
'model': self.current_model_type,
|
"message": message.processed_plain_text,
|
||||||
|
"model": self.current_model_type,
|
||||||
# 'reasoning_check': reasoning_content_check,
|
# 'reasoning_check': reasoning_content_check,
|
||||||
# 'response_check': content_check,
|
# 'response_check': content_check,
|
||||||
'reasoning': reasoning_content,
|
"reasoning": reasoning_content,
|
||||||
'response': content,
|
"response": content,
|
||||||
'prompt': prompt,
|
"prompt": prompt,
|
||||||
'prompt_check': prompt_check
|
"prompt_check": prompt_check,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
async def _get_emotion_tags(self, content: str) -> List[str]:
|
async def _get_emotion_tags(self, content: str) -> List[str]:
|
||||||
"""提取情感标签"""
|
"""提取情感标签"""
|
||||||
try:
|
try:
|
||||||
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
|
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
|
||||||
只输出标签就好,不要输出其他内容:
|
只输出标签就好,不要输出其他内容:
|
||||||
内容:{content}
|
内容:{content}
|
||||||
输出:
|
输出:
|
||||||
'''
|
"""
|
||||||
content, _ = await self.model_v25.generate_response(prompt)
|
content, _ = await self.model_v25.generate_response(prompt)
|
||||||
content=content.strip()
|
content = content.strip()
|
||||||
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
|
if content in [
|
||||||
|
"happy",
|
||||||
|
"angry",
|
||||||
|
"sad",
|
||||||
|
"surprised",
|
||||||
|
"disgusted",
|
||||||
|
"fearful",
|
||||||
|
"neutral",
|
||||||
|
]:
|
||||||
return [content]
|
return [content]
|
||||||
else:
|
else:
|
||||||
return ["neutral"]
|
return ["neutral"]
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from typing import Dict, ForwardRef, List, Optional, Union
|
|||||||
import urllib3
|
import urllib3
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
@@ -110,23 +109,30 @@ class MessageRecv(MessageBase):
|
|||||||
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageProcessBase(MessageBase):
|
class Message(MessageBase):
|
||||||
"""消息处理基类,用于处理中和发送中的消息"""
|
chat_stream: ChatStream=None
|
||||||
|
reply: Optional['Message'] = None
|
||||||
|
detailed_plain_text: str = ""
|
||||||
|
processed_plain_text: str = ""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
|
time: int,
|
||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
|
user_info: UserInfo,
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
reply: Optional['MessageRecv'] = None
|
reply: Optional['MessageRecv'] = None,
|
||||||
|
detailed_plain_text: str = "",
|
||||||
|
processed_plain_text: str = "",
|
||||||
):
|
):
|
||||||
# 构造基础消息信息
|
# 构造基础消息信息
|
||||||
message_info = BaseMessageInfo(
|
message_info = BaseMessageInfo(
|
||||||
platform=chat_stream.platform,
|
platform=chat_stream.platform,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
time=int(time.time()),
|
time=time,
|
||||||
group_info=chat_stream.group_info,
|
group_info=chat_stream.group_info,
|
||||||
user_info=chat_stream.user_info
|
user_info=user_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
@@ -136,17 +142,41 @@ class MessageProcessBase(MessageBase):
|
|||||||
raw_message=None
|
raw_message=None
|
||||||
)
|
)
|
||||||
|
|
||||||
# 处理状态相关属性
|
self.chat_stream = chat_stream
|
||||||
self.thinking_start_time = int(time.time())
|
|
||||||
self.thinking_time = 0
|
|
||||||
|
|
||||||
# 文本处理相关属性
|
# 文本处理相关属性
|
||||||
self.processed_plain_text = ""
|
self.processed_plain_text = detailed_plain_text
|
||||||
self.detailed_plain_text = ""
|
self.detailed_plain_text = processed_plain_text
|
||||||
|
|
||||||
# 回复消息
|
# 回复消息
|
||||||
self.reply = reply
|
self.reply = reply
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MessageProcessBase(Message):
|
||||||
|
"""消息处理基类,用于处理中和发送中的消息"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_id: str,
|
||||||
|
chat_stream: ChatStream,
|
||||||
|
bot_user_info: UserInfo,
|
||||||
|
message_segment: Optional[Seg] = None,
|
||||||
|
reply: Optional['MessageRecv'] = None
|
||||||
|
):
|
||||||
|
# 调用父类初始化
|
||||||
|
super().__init__(
|
||||||
|
message_id=message_id,
|
||||||
|
time=int(time.time()),
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
user_info=bot_user_info,
|
||||||
|
message_segment=message_segment,
|
||||||
|
reply=reply
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理状态相关属性
|
||||||
|
self.thinking_start_time = int(time.time())
|
||||||
|
self.thinking_time = 0
|
||||||
|
|
||||||
def update_thinking_time(self) -> float:
|
def update_thinking_time(self) -> float:
|
||||||
"""更新思考时间"""
|
"""更新思考时间"""
|
||||||
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
|
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
|
||||||
@@ -224,12 +254,14 @@ class MessageThinking(MessageProcessBase):
|
|||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
|
bot_user_info: UserInfo,
|
||||||
reply: Optional['MessageRecv'] = None
|
reply: Optional['MessageRecv'] = None
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
|
bot_user_info=bot_user_info,
|
||||||
message_segment=None, # 思考状态不需要消息段
|
message_segment=None, # 思考状态不需要消息段
|
||||||
reply=reply
|
reply=reply
|
||||||
)
|
)
|
||||||
@@ -237,15 +269,6 @@ class MessageThinking(MessageProcessBase):
|
|||||||
# 思考状态特有属性
|
# 思考状态特有属性
|
||||||
self.interrupt = False
|
self.interrupt = False
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
|
|
||||||
"""从聊天流创建思考状态消息"""
|
|
||||||
return cls(
|
|
||||||
message_id=message_id,
|
|
||||||
chat_stream=chat_stream,
|
|
||||||
reply=reply
|
|
||||||
)
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSending(MessageProcessBase):
|
class MessageSending(MessageProcessBase):
|
||||||
"""发送状态的消息类"""
|
"""发送状态的消息类"""
|
||||||
@@ -254,6 +277,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
|
bot_user_info: UserInfo,
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
reply: Optional['MessageRecv'] = None,
|
reply: Optional['MessageRecv'] = None,
|
||||||
is_head: bool = False,
|
is_head: bool = False,
|
||||||
@@ -263,6 +287,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
super().__init__(
|
super().__init__(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
chat_stream=chat_stream,
|
chat_stream=chat_stream,
|
||||||
|
bot_user_info=bot_user_info,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
reply=reply
|
reply=reply
|
||||||
)
|
)
|
||||||
@@ -296,11 +321,17 @@ class MessageSending(MessageProcessBase):
|
|||||||
message_id=thinking.message_info.message_id,
|
message_id=thinking.message_info.message_id,
|
||||||
chat_stream=thinking.chat_stream,
|
chat_stream=thinking.chat_stream,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
|
bot_user_info=thinking.message_info.user_info,
|
||||||
reply=thinking.reply,
|
reply=thinking.reply,
|
||||||
is_head=is_head,
|
is_head=is_head,
|
||||||
is_emoji=is_emoji
|
is_emoji=is_emoji
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
ret= super().to_dict()
|
||||||
|
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict()
|
||||||
|
return ret
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSet:
|
class MessageSet:
|
||||||
"""消息集合类,可以存储多个发送消息"""
|
"""消息集合类,可以存储多个发送消息"""
|
||||||
|
|||||||
@@ -79,6 +79,21 @@ class GroupInfo:
|
|||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
|
|
||||||
|
def from_dict(cls, data: Dict) -> 'GroupInfo':
|
||||||
|
"""从字典创建GroupInfo实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含必要字段的字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GroupInfo: 新的实例
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
platform=data.get('platform'),
|
||||||
|
group_id=data.get('group_id'),
|
||||||
|
group_name=data.get('group_name',None)
|
||||||
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class UserInfo:
|
class UserInfo:
|
||||||
"""用户信息类"""
|
"""用户信息类"""
|
||||||
@@ -91,6 +106,22 @@ class UserInfo:
|
|||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
|
|
||||||
|
def from_dict(cls, data: Dict) -> 'UserInfo':
|
||||||
|
"""从字典创建UserInfo实例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含必要字段的字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserInfo: 新的实例
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
platform=data.get('platform'),
|
||||||
|
user_id=data.get('user_id'),
|
||||||
|
user_nickname=data.get('user_nickname',None),
|
||||||
|
user_cardname=data.get('user_cardname',None)
|
||||||
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseMessageInfo:
|
class BaseMessageInfo:
|
||||||
"""消息信息类"""
|
"""消息信息类"""
|
||||||
@@ -147,7 +178,7 @@ class MessageBase:
|
|||||||
"""
|
"""
|
||||||
message_info = BaseMessageInfo(**data.get('message_info', {}))
|
message_info = BaseMessageInfo(**data.get('message_info', {}))
|
||||||
message_segment = Seg(**data.get('message_segment', {}))
|
message_segment = Seg(**data.get('message_segment', {}))
|
||||||
raw_message = data.get('raw_message')
|
raw_message = data.get('raw_message',None)
|
||||||
return cls(
|
return cls(
|
||||||
message_info=message_info,
|
message_info=message_info,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
|
|||||||
@@ -139,27 +139,24 @@ class MessageSendCQ(MessageCQ):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: int,
|
data: Dict
|
||||||
user_id: int,
|
|
||||||
message_segment: Seg,
|
|
||||||
group_id: Optional[int] = None,
|
|
||||||
reply_to_message_id: Optional[int] = None,
|
|
||||||
platform: str = "qq"
|
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(message_id, user_id, group_id, platform)
|
message_info = BaseMessageInfo(**data.get('message_info', {}))
|
||||||
|
message_segment = Seg(**data.get('message_segment', {}))
|
||||||
|
super().__init__(
|
||||||
|
message_info.message_id,
|
||||||
|
message_info.user_info.user_id,
|
||||||
|
message_info.group_info.group_id if message_info.group_info else None,
|
||||||
|
message_info.platform)
|
||||||
|
|
||||||
self.message_segment = message_segment
|
self.message_segment = message_segment
|
||||||
self.raw_message = self._generate_raw_message(reply_to_message_id)
|
self.raw_message = self._generate_raw_message()
|
||||||
|
|
||||||
def _generate_raw_message(self, reply_to_message_id: Optional[int] = None) -> str:
|
def _generate_raw_message(self, ) -> str:
|
||||||
"""将Seg对象转换为raw_message"""
|
"""将Seg对象转换为raw_message"""
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
# 添加回复消息
|
|
||||||
if reply_to_message_id:
|
|
||||||
segments.append(cq_code_tool.create_reply_cq(reply_to_message_id))
|
|
||||||
|
|
||||||
# 处理消息段
|
# 处理消息段
|
||||||
if self.message_segment.type == 'seglist':
|
if self.message_segment.type == 'seglist':
|
||||||
for seg in self.message_segment.data:
|
for seg in self.message_segment.data:
|
||||||
|
|||||||
@@ -29,13 +29,12 @@ class Message_Sender:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""发送消息"""
|
"""发送消息"""
|
||||||
if isinstance(message, MessageSending):
|
if isinstance(message, MessageSending):
|
||||||
|
message_json = message.to_dict()
|
||||||
message_send=MessageSendCQ(
|
message_send=MessageSendCQ(
|
||||||
message_id=message.message_id,
|
data=message_json
|
||||||
user_id=message.message_info.user_info.user_id,
|
|
||||||
message_segment=message.message_segment,
|
|
||||||
reply=message.reply
|
|
||||||
)
|
)
|
||||||
if message.message_info.group_info:
|
|
||||||
|
if message_send.message_info.group_info:
|
||||||
try:
|
try:
|
||||||
await self._current_bot.send_group_msg(
|
await self._current_bot.send_group_msg(
|
||||||
group_id=message.message_info.group_info.group_id,
|
group_id=message.message_info.group_info.group_id,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from ..moods.moods import MoodManager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
||||||
|
from .chat_stream import ChatStream, chat_manager
|
||||||
|
|
||||||
|
|
||||||
class PromptBuilder:
|
class PromptBuilder:
|
||||||
@@ -22,7 +23,7 @@ class PromptBuilder:
|
|||||||
message_txt: str,
|
message_txt: str,
|
||||||
sender_name: str = "某人",
|
sender_name: str = "某人",
|
||||||
relationship_value: float = 0.0,
|
relationship_value: float = 0.0,
|
||||||
group_id: Optional[int] = None) -> tuple[str, str]:
|
stream_id: Optional[int] = None) -> tuple[str, str]:
|
||||||
"""构建prompt
|
"""构建prompt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -72,11 +73,17 @@ class PromptBuilder:
|
|||||||
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}秒")
|
||||||
|
|
||||||
# 获取聊天上下文
|
# 获取聊天上下文
|
||||||
|
chat_in_group=True
|
||||||
chat_talking_prompt = ''
|
chat_talking_prompt = ''
|
||||||
if group_id:
|
if stream_id:
|
||||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
||||||
|
chat_stream=chat_manager.get_stream(stream_id)
|
||||||
|
if chat_stream.group_info:
|
||||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||||
|
else:
|
||||||
|
chat_in_group=False
|
||||||
|
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
|
||||||
|
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -112,8 +119,10 @@ class PromptBuilder:
|
|||||||
|
|
||||||
#激活prompt构建
|
#激活prompt构建
|
||||||
activate_prompt = ''
|
activate_prompt = ''
|
||||||
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
if chat_in_group:
|
||||||
|
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
||||||
|
else:
|
||||||
|
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
||||||
#检测机器人相关词汇
|
#检测机器人相关词汇
|
||||||
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
|
bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
|
||||||
is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
|
is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
|
||||||
@@ -129,16 +138,20 @@ class PromptBuilder:
|
|||||||
probability_3 = global_config.PERSONALITY_3
|
probability_3 = global_config.PERSONALITY_3
|
||||||
prompt_personality = ''
|
prompt_personality = ''
|
||||||
personality_choice = random.random()
|
personality_choice = random.random()
|
||||||
|
if chat_in_group:
|
||||||
|
prompt_in_group=f"你正在浏览{chat_stream.platform}群"
|
||||||
|
else:
|
||||||
|
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
|
||||||
if personality_choice < probability_1: # 第一种人格
|
if personality_choice < probability_1: # 第一种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[0]},{prompt_in_group},{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{is_bot_prompt}
|
||||||
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
|
||||||
elif personality_choice < probability_1 + probability_2: # 第二种人格
|
elif personality_choice < probability_1 + probability_2: # 第二种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]}, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[1]},{prompt_in_group},{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
||||||
请你表达自己的见解和观点。可以有个性。'''
|
请你表达自己的见解和观点。可以有个性。'''
|
||||||
else: # 第三种人格
|
else: # 第三种人格
|
||||||
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]}, 你正在浏览qq群,{promt_info_prompt},
|
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},{personality[2]},{prompt_in_group},{promt_info_prompt},
|
||||||
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{is_bot_prompt}
|
||||||
请你表达自己的见解和观点。可以有个性。'''
|
请你表达自己的见解和观点。可以有个性。'''
|
||||||
|
|
||||||
|
|||||||
@@ -16,17 +16,16 @@ class Impression:
|
|||||||
class Relationship:
|
class Relationship:
|
||||||
user_id: int = None
|
user_id: int = None
|
||||||
platform: str = None
|
platform: str = None
|
||||||
platform: str = None
|
|
||||||
gender: str = None
|
gender: str = None
|
||||||
age: int = None
|
age: int = None
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
relationship_value: float = None
|
relationship_value: float = None
|
||||||
saved = False
|
saved = False
|
||||||
|
|
||||||
def __init__(self, chat:ChatStream,data:dict):
|
def __init__(self, chat:ChatStream=None,data:dict=None):
|
||||||
self.user_id=chat.user_info.user_id
|
self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0)
|
||||||
self.platform=chat.platform
|
self.platform=chat.platform if chat.user_info else data.get('platform','')
|
||||||
self.nickname=chat.user_info.user_nickname
|
self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','')
|
||||||
self.relationship_value=data.get('relationship_value',0)
|
self.relationship_value=data.get('relationship_value',0)
|
||||||
self.age=data.get('age',0)
|
self.age=data.get('age',0)
|
||||||
self.gender=data.get('gender','')
|
self.gender=data.get('gender','')
|
||||||
@@ -35,7 +34,6 @@ class Relationship:
|
|||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
|
||||||
|
|
||||||
async def update_relationship(self,
|
async def update_relationship(self,
|
||||||
chat_stream:ChatStream,
|
chat_stream:ChatStream,
|
||||||
@@ -43,9 +41,7 @@ class RelationshipManager:
|
|||||||
**kwargs) -> Optional[Relationship]:
|
**kwargs) -> Optional[Relationship]:
|
||||||
"""更新或创建关系
|
"""更新或创建关系
|
||||||
Args:
|
Args:
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
chat_stream: 聊天流对象
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
data: 字典格式的数据(可选)
|
data: 字典格式的数据(可选)
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
Returns:
|
Returns:
|
||||||
@@ -66,44 +62,18 @@ class RelationshipManager:
|
|||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(key)
|
relationship = self.relationships.get(key)
|
||||||
relationship = self.relationships.get(key)
|
|
||||||
if relationship:
|
if relationship:
|
||||||
# 如果存在,更新现有对象
|
# 如果存在,更新现有对象
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
for k, value in data.items():
|
for k, value in data.items():
|
||||||
if hasattr(relationship, k) and value is not None:
|
if hasattr(relationship, k) and value is not None:
|
||||||
setattr(relationship, k, value)
|
setattr(relationship, k, value)
|
||||||
for k, value in data.items():
|
|
||||||
if hasattr(relationship, k) and value is not None:
|
|
||||||
setattr(relationship, k, value)
|
|
||||||
else:
|
|
||||||
for k, value in kwargs.items():
|
|
||||||
if hasattr(relationship, k) and value is not None:
|
|
||||||
setattr(relationship, k, value)
|
|
||||||
for k, value in kwargs.items():
|
|
||||||
if hasattr(relationship, k) and value is not None:
|
|
||||||
setattr(relationship, k, value)
|
|
||||||
else:
|
else:
|
||||||
# 如果不存在,创建新对象
|
# 如果不存在,创建新对象
|
||||||
if user_info is not None:
|
if chat_stream.user_info is not None:
|
||||||
relationship = Relationship(user_info=user_info, **kwargs)
|
relationship = Relationship(chat=chat_stream, **kwargs)
|
||||||
elif isinstance(data, dict):
|
|
||||||
data['platform'] = platform
|
|
||||||
relationship = Relationship(user_id=user_id, data=data)
|
|
||||||
else:
|
else:
|
||||||
kwargs['platform'] = platform
|
raise ValueError("必须提供user_id或user_info")
|
||||||
kwargs['user_id'] = user_id
|
|
||||||
relationship = Relationship(**kwargs)
|
|
||||||
self.relationships[key] = relationship
|
|
||||||
if user_info is not None:
|
|
||||||
relationship = Relationship(user_info=user_info, **kwargs)
|
|
||||||
elif isinstance(data, dict):
|
|
||||||
data['platform'] = platform
|
|
||||||
relationship = Relationship(user_id=user_id, data=data)
|
|
||||||
else:
|
|
||||||
kwargs['platform'] = platform
|
|
||||||
kwargs['user_id'] = user_id
|
|
||||||
relationship = Relationship(**kwargs)
|
|
||||||
self.relationships[key] = relationship
|
self.relationships[key] = relationship
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
@@ -113,36 +83,7 @@ class RelationshipManager:
|
|||||||
return relationship
|
return relationship
|
||||||
|
|
||||||
async def update_relationship_value(self,
|
async def update_relationship_value(self,
|
||||||
user_id: int = None,
|
chat_stream:ChatStream,
|
||||||
platform: str = None,
|
|
||||||
user_info: UserInfo = None,
|
|
||||||
**kwargs) -> Optional[Relationship]:
|
|
||||||
"""更新关系值
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
**kwargs: 其他参数
|
|
||||||
Returns:
|
|
||||||
Relationship: 关系对象
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or 'qq'
|
|
||||||
else:
|
|
||||||
platform = platform or 'qq'
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
# 使用(user_id, platform)作为键
|
|
||||||
key = (user_id, platform)
|
|
||||||
|
|
||||||
async def update_relationship_value(self,
|
|
||||||
user_id: int = None,
|
|
||||||
platform: str = None,
|
|
||||||
user_info: UserInfo = None,
|
|
||||||
**kwargs) -> Optional[Relationship]:
|
**kwargs) -> Optional[Relationship]:
|
||||||
"""更新关系值
|
"""更新关系值
|
||||||
Args:
|
Args:
|
||||||
@@ -154,6 +95,7 @@ class RelationshipManager:
|
|||||||
Relationship: 关系对象
|
Relationship: 关系对象
|
||||||
"""
|
"""
|
||||||
# 确定user_id和platform
|
# 确定user_id和platform
|
||||||
|
user_info = chat_stream.user_info
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
user_id = user_info.user_id
|
user_id = user_info.user_id
|
||||||
platform = user_info.platform or 'qq'
|
platform = user_info.platform or 'qq'
|
||||||
@@ -168,10 +110,7 @@ class RelationshipManager:
|
|||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(key)
|
relationship = self.relationships.get(key)
|
||||||
relationship = self.relationships.get(key)
|
|
||||||
if relationship:
|
if relationship:
|
||||||
for k, value in kwargs.items():
|
|
||||||
if k == 'relationship_value':
|
|
||||||
for k, value in kwargs.items():
|
for k, value in kwargs.items():
|
||||||
if k == 'relationship_value':
|
if k == 'relationship_value':
|
||||||
relationship.relationship_value += value
|
relationship.relationship_value += value
|
||||||
@@ -181,43 +120,12 @@ class RelationshipManager:
|
|||||||
else:
|
else:
|
||||||
# 如果不存在且提供了user_info,则创建新的关系
|
# 如果不存在且提供了user_info,则创建新的关系
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
|
||||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
|
||||||
# 如果不存在且提供了user_info,则创建新的关系
|
|
||||||
if user_info is not None:
|
|
||||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
|
||||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_relationship(self,
|
def get_relationship(self,
|
||||||
user_id: int = None,
|
chat_stream:ChatStream) -> Optional[Relationship]:
|
||||||
platform: str = None,
|
|
||||||
user_info: UserInfo = None) -> Optional[Relationship]:
|
|
||||||
"""获取用户关系对象
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
Returns:
|
|
||||||
Relationship: 关系对象
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or 'qq'
|
|
||||||
else:
|
|
||||||
platform = platform or 'qq'
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
key = (user_id, platform)
|
|
||||||
if key in self.relationships:
|
|
||||||
return self.relationships[key]
|
|
||||||
def get_relationship(self,
|
|
||||||
user_id: int = None,
|
|
||||||
platform: str = None,
|
|
||||||
user_info: UserInfo = None) -> Optional[Relationship]:
|
|
||||||
"""获取用户关系对象
|
"""获取用户关系对象
|
||||||
Args:
|
Args:
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
@@ -227,6 +135,8 @@ class RelationshipManager:
|
|||||||
Relationship: 关系对象
|
Relationship: 关系对象
|
||||||
"""
|
"""
|
||||||
# 确定user_id和platform
|
# 确定user_id和platform
|
||||||
|
user_info = chat_stream.user_info
|
||||||
|
platform = chat_stream.user_info.platform or 'qq'
|
||||||
if user_info is not None:
|
if user_info is not None:
|
||||||
user_id = user_info.user_id
|
user_id = user_info.user_id
|
||||||
platform = user_info.platform or 'qq'
|
platform = user_info.platform or 'qq'
|
||||||
@@ -248,18 +158,10 @@ class RelationshipManager:
|
|||||||
if 'platform' not in data:
|
if 'platform' not in data:
|
||||||
data['platform'] = 'qq'
|
data['platform'] = 'qq'
|
||||||
|
|
||||||
rela = Relationship(data=data)
|
|
||||||
"""从数据库加载或创建新的关系对象"""
|
|
||||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
|
||||||
if 'platform' not in data:
|
|
||||||
data['platform'] = 'qq'
|
|
||||||
|
|
||||||
rela = Relationship(data=data)
|
rela = Relationship(data=data)
|
||||||
rela.saved = True
|
rela.saved = True
|
||||||
key = (rela.user_id, rela.platform)
|
key = (rela.user_id, rela.platform)
|
||||||
self.relationships[key] = rela
|
self.relationships[key] = rela
|
||||||
key = (rela.user_id, rela.platform)
|
|
||||||
self.relationships[key] = rela
|
|
||||||
return rela
|
return rela
|
||||||
|
|
||||||
async def load_all_relationships(self):
|
async def load_all_relationships(self):
|
||||||
@@ -277,7 +179,6 @@ class RelationshipManager:
|
|||||||
# 依次加载每条记录
|
# 依次加载每条记录
|
||||||
for data in all_relationships:
|
for data in all_relationships:
|
||||||
await self.load_relationship(data)
|
await self.load_relationship(data)
|
||||||
await self.load_relationship(data)
|
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -288,19 +189,15 @@ class RelationshipManager:
|
|||||||
async def _save_all_relationships(self):
|
async def _save_all_relationships(self):
|
||||||
"""将所有关系数据保存到数据库"""
|
"""将所有关系数据保存到数据库"""
|
||||||
# 保存所有关系数据
|
# 保存所有关系数据
|
||||||
for (userid, platform), relationship in self.relationships.items():
|
|
||||||
for (userid, platform), relationship in self.relationships.items():
|
for (userid, platform), relationship in self.relationships.items():
|
||||||
if not relationship.saved:
|
if not relationship.saved:
|
||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
|
|
||||||
async def storage_relationship(self, relationship: Relationship):
|
|
||||||
"""将关系记录存储到数据库中"""
|
|
||||||
async def storage_relationship(self, relationship: Relationship):
|
async def storage_relationship(self, relationship: Relationship):
|
||||||
"""将关系记录存储到数据库中"""
|
"""将关系记录存储到数据库中"""
|
||||||
user_id = relationship.user_id
|
user_id = relationship.user_id
|
||||||
platform = relationship.platform
|
platform = relationship.platform
|
||||||
platform = relationship.platform
|
|
||||||
nickname = relationship.nickname
|
nickname = relationship.nickname
|
||||||
relationship_value = relationship.relationship_value
|
relationship_value = relationship.relationship_value
|
||||||
gender = relationship.gender
|
gender = relationship.gender
|
||||||
@@ -309,10 +206,8 @@ class RelationshipManager:
|
|||||||
|
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
db.db.relationships.update_one(
|
db.db.relationships.update_one(
|
||||||
{'user_id': user_id, 'platform': platform},
|
|
||||||
{'user_id': user_id, 'platform': platform},
|
{'user_id': user_id, 'platform': platform},
|
||||||
{'$set': {
|
{'$set': {
|
||||||
'platform': platform,
|
|
||||||
'platform': platform,
|
'platform': platform,
|
||||||
'nickname': nickname,
|
'nickname': nickname,
|
||||||
'relationship_value': relationship_value,
|
'relationship_value': relationship_value,
|
||||||
@@ -323,27 +218,6 @@ class RelationshipManager:
|
|||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_name(self,
|
|
||||||
user_id: int = None,
|
|
||||||
platform: str = None,
|
|
||||||
user_info: UserInfo = None) -> str:
|
|
||||||
"""获取用户昵称
|
|
||||||
Args:
|
|
||||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
|
||||||
platform: 平台(可选,如果提供user_info则不需要)
|
|
||||||
user_info: 用户信息对象(可选)
|
|
||||||
Returns:
|
|
||||||
str: 用户昵称
|
|
||||||
"""
|
|
||||||
# 确定user_id和platform
|
|
||||||
if user_info is not None:
|
|
||||||
user_id = user_info.user_id
|
|
||||||
platform = user_info.platform or 'qq'
|
|
||||||
else:
|
|
||||||
platform = platform or 'qq'
|
|
||||||
|
|
||||||
if user_id is None:
|
|
||||||
raise ValueError("必须提供user_id或user_info")
|
|
||||||
|
|
||||||
def get_name(self,
|
def get_name(self,
|
||||||
user_id: int = None,
|
user_id: int = None,
|
||||||
@@ -370,11 +244,6 @@ class RelationshipManager:
|
|||||||
# 确保user_id是整数类型
|
# 确保user_id是整数类型
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
key = (user_id, platform)
|
key = (user_id, platform)
|
||||||
if key in self.relationships:
|
|
||||||
return self.relationships[key].nickname
|
|
||||||
elif user_info is not None:
|
|
||||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
|
||||||
key = (user_id, platform)
|
|
||||||
if key in self.relationships:
|
if key in self.relationships:
|
||||||
return self.relationships[key].nickname
|
return self.relationships[key].nickname
|
||||||
elif user_info is not None:
|
elif user_info is not None:
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ class MessageStorage:
|
|||||||
"time": message.message_info.time,
|
"time": message.message_info.time,
|
||||||
"chat_id":chat_stream.stream_id,
|
"chat_id":chat_stream.stream_id,
|
||||||
"chat_info": chat_stream.to_dict(),
|
"chat_info": chat_stream.to_dict(),
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
"user_info": message.message_info.user_info.to_dict(),
|
||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
}
|
}
|
||||||
self.db.db.messages.insert_one(message_data)
|
self.db.db.messages.insert_one(message_data)
|
||||||
|
|||||||
@@ -11,7 +11,9 @@ from nonebot import get_driver
|
|||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from ..utils.typo_generator import ChineseTypoGenerator
|
from ..utils.typo_generator import ChineseTypoGenerator
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message_cq import Message
|
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
|
||||||
|
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -32,7 +34,7 @@ def db_message_to_str(message_dict: Dict) -> str:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def is_mentioned_bot_in_message(message: Message) -> bool:
|
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
|
||||||
"""检查消息是否提到了机器人"""
|
"""检查消息是否提到了机器人"""
|
||||||
keywords = [global_config.BOT_NICKNAME]
|
keywords = [global_config.BOT_NICKNAME]
|
||||||
for keyword in keywords:
|
for keyword in keywords:
|
||||||
@@ -41,15 +43,6 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_mentioned_bot_in_txt(message: str) -> bool:
|
|
||||||
"""检查消息是否提到了机器人"""
|
|
||||||
keywords = [global_config.BOT_NICKNAME]
|
|
||||||
for keyword in keywords:
|
|
||||||
if keyword in message:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_embedding(text):
|
async def get_embedding(text):
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
llm = LLM_request(model=global_config.embedding)
|
llm = LLM_request(model=global_config.embedding)
|
||||||
@@ -84,10 +77,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
|||||||
|
|
||||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
chat_id = closest_record['chat_id'] # 获取groupid
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_records = list(db.db.messages.find(
|
chat_records = list(db.db.messages.find(
|
||||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
{"time": {"$gt": closest_time}, "chat_id": chat_id}
|
||||||
).sort('time', 1).limit(length))
|
).sort('time', 1).limit(length))
|
||||||
|
|
||||||
# 更新每条消息的memorized属性
|
# 更新每条消息的memorized属性
|
||||||
@@ -111,7 +104,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
|||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
|
||||||
"""从数据库获取群组最近的消息记录
|
"""从数据库获取群组最近的消息记录
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -125,35 +118,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
|
|
||||||
# 从数据库获取最近消息
|
# 从数据库获取最近消息
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"chat_id": chat_id},
|
||||||
# {
|
|
||||||
# "time": 1,
|
|
||||||
# "user_id": 1,
|
|
||||||
# "user_nickname": 1,
|
|
||||||
# "message_id": 1,
|
|
||||||
# "raw_message": 1,
|
|
||||||
# "processed_text": 1
|
|
||||||
# }
|
|
||||||
).sort("time", -1).limit(limit))
|
).sort("time", -1).limit(limit))
|
||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 转换为 Message对象列表
|
# 转换为 Message对象列表
|
||||||
from .message_cq import Message
|
|
||||||
message_objects = []
|
message_objects = []
|
||||||
for msg_data in recent_messages:
|
for msg_data in recent_messages:
|
||||||
try:
|
try:
|
||||||
|
chat_info=msg_data.get("chat_info",{})
|
||||||
|
chat_stream=ChatStream.from_dict(chat_info)
|
||||||
|
user_info=msg_data.get("user_info",{})
|
||||||
|
user_info=UserInfo.from_dict(user_info)
|
||||||
msg = Message(
|
msg = Message(
|
||||||
time=msg_data["time"],
|
|
||||||
user_id=msg_data["user_id"],
|
|
||||||
user_nickname=msg_data.get("user_nickname", ""),
|
|
||||||
message_id=msg_data["message_id"],
|
message_id=msg_data["message_id"],
|
||||||
raw_message=msg_data["raw_message"],
|
chat_stream=chat_stream,
|
||||||
|
time=msg_data["time"],
|
||||||
|
user_info=user_info,
|
||||||
processed_plain_text=msg_data.get("processed_text", ""),
|
processed_plain_text=msg_data.get("processed_text", ""),
|
||||||
group_id=group_id
|
detailed_plain_text=msg_data.get("detailed_plain_text", "")
|
||||||
)
|
)
|
||||||
await msg.initialize()
|
|
||||||
message_objects.append(msg)
|
message_objects.append(msg)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
print("[WARNING] 数据库中存在无效的消息")
|
print("[WARNING] 数据库中存在无效的消息")
|
||||||
@@ -164,13 +150,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
return message_objects
|
return message_objects
|
||||||
|
|
||||||
|
|
||||||
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
|
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"chat_id": chat_stream_id},
|
||||||
{
|
{
|
||||||
"time": 1, # 返回时间字段
|
"time": 1, # 返回时间字段
|
||||||
"user_id": 1, # 返回用户ID字段
|
"chat_id":1,
|
||||||
"user_nickname": 1, # 返回用户昵称字段
|
"chat_info":1,
|
||||||
|
"user_info": 1,
|
||||||
"message_id": 1, # 返回消息ID字段
|
"message_id": 1, # 返回消息ID字段
|
||||||
"detailed_plain_text": 1 # 返回处理后的文本字段
|
"detailed_plain_text": 1 # 返回处理后的文本字段
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user