refractor: 几乎写完了,进入测试阶段
This commit is contained in:
@@ -10,18 +10,19 @@ from .config import global_config
|
||||
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||
from .llm_generator import ResponseGenerator
|
||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from .message_cq import (
|
||||
Message,
|
||||
Message_Sending,
|
||||
Message_Thinking, # 导入 Message_Thinking 类
|
||||
MessageSet,
|
||||
MessageRecvCQ,
|
||||
MessageSendCQ,
|
||||
)
|
||||
from .chat_stream import chat_manager
|
||||
from .message_sender import message_manager # 导入新的消息管理器
|
||||
from .relationship_manager import relationship_manager
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||
from .utils_image import image_path_to_base64
|
||||
from .willing_manager import willing_manager # 导入意愿管理器
|
||||
|
||||
from .message_base import UserInfo, GroupInfo, Seg
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
@@ -43,12 +44,9 @@ class ChatBot:
|
||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||
"""处理收到的群消息"""
|
||||
|
||||
if event.group_id not in global_config.talk_allowed_groups:
|
||||
return
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
if event.user_id in global_config.ban_user_id:
|
||||
return
|
||||
|
||||
|
||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
@@ -56,25 +54,42 @@ class ChatBot:
|
||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||
|
||||
message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=event.user_id,
|
||||
message_cq=MessageRecvCQ(
|
||||
message_id=event.message_id,
|
||||
user_cardname=sender_info['card'],
|
||||
raw_message=str(event.original_message),
|
||||
plain_text=event.get_plaintext(),
|
||||
user_id=event.user_id,
|
||||
raw_message=str(event.original_message),
|
||||
group_id=event.group_id,
|
||||
reply_message=event.reply,
|
||||
platform='qq'
|
||||
)
|
||||
await message.initialize()
|
||||
|
||||
message_json=message_cq.to_dict()
|
||||
message=MessageRecv(**message_json)
|
||||
await message.process()
|
||||
groupinfo=message.message_info.group_info
|
||||
userinfo=message.message_info.user_info
|
||||
messageinfo=message.message_info
|
||||
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
||||
if groupinfo:
|
||||
if groupinfo.group_id not in global_config.talk_allowed_groups:
|
||||
return
|
||||
else:
|
||||
if userinfo:
|
||||
if userinfo.user_id in []:
|
||||
pass
|
||||
else:
|
||||
return
|
||||
else:
|
||||
return
|
||||
if userinfo.user_id in global_config.ban_user_id:
|
||||
return
|
||||
# 过滤词
|
||||
for word in global_config.ban_words:
|
||||
if word in message.detailed_plain_text:
|
||||
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||
if word in message.processed_plain_text:
|
||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||
return
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||
|
||||
|
||||
|
||||
@@ -88,44 +103,58 @@ class ChatBot:
|
||||
await self.storage.store_message(message, topic[0] if topic else None)
|
||||
|
||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||
reply_probability = willing_manager.change_reply_willing_received(
|
||||
event.group_id,
|
||||
topic[0] if topic else None,
|
||||
is_mentioned,
|
||||
global_config,
|
||||
event.user_id,
|
||||
message.is_emoji,
|
||||
interested_rate
|
||||
reply_probability = await willing_manager.change_reply_willing_received(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
topic=topic[0] if topic else None,
|
||||
is_mentioned_bot=is_mentioned,
|
||||
config=global_config,
|
||||
is_emoji=message.is_emoji,
|
||||
interested_rate=interested_rate
|
||||
)
|
||||
current_willing = willing_manager.get_willing(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo
|
||||
)
|
||||
current_willing = willing_manager.get_willing(event.group_id)
|
||||
|
||||
|
||||
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
||||
print(f"\033[1;32m[{current_time}][{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
||||
|
||||
response = ""
|
||||
|
||||
if random() < reply_probability:
|
||||
|
||||
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=messageinfo.platform
|
||||
)
|
||||
tinking_time_point = round(time.time(), 2)
|
||||
think_id = 'mt' + str(tinking_time_point)
|
||||
thinking_message = Message_Thinking(message=message,message_id=think_id)
|
||||
thinking_message = MessageThinking.from_chat_stream(
|
||||
chat_stream=chat,
|
||||
message_id=think_id,
|
||||
reply=message
|
||||
)
|
||||
|
||||
message_manager.add_message(thinking_message)
|
||||
|
||||
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
||||
willing_manager.change_reply_willing_sent(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo
|
||||
)
|
||||
|
||||
response,raw_content = await self.gpt.generate_response(message)
|
||||
|
||||
if response:
|
||||
container = message_manager.get_container(event.group_id)
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
thinking_message = None
|
||||
# 找到message,删除
|
||||
for msg in container.messages:
|
||||
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
|
||||
if isinstance(msg, MessageThinking) and msg.message_id == think_id:
|
||||
thinking_message = msg
|
||||
container.messages.remove(msg)
|
||||
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
|
||||
break
|
||||
|
||||
# 如果找不到思考消息,直接返回
|
||||
@@ -135,11 +164,10 @@ class ChatBot:
|
||||
|
||||
#记录开始思考的时间,避免从思考到回复的时间太久
|
||||
thinking_start_time = thinking_message.thinking_start_time
|
||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
||||
message_set = MessageSet(chat, think_id)
|
||||
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||
accu_typing_time = 0
|
||||
|
||||
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
|
||||
mark_head = False
|
||||
for msg in response:
|
||||
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
|
||||
@@ -148,22 +176,16 @@ class ChatBot:
|
||||
accu_typing_time += typing_time
|
||||
timepoint = tinking_time_point + accu_typing_time
|
||||
|
||||
bot_message = Message_Sending(
|
||||
group_id=event.group_id,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_segment = Seg(type='text', data=msg)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
raw_message=msg,
|
||||
plain_text=msg,
|
||||
processed_plain_text=msg,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
group_name=message.group_name,
|
||||
time=timepoint, #记录了回复生成的时间
|
||||
thinking_start_time=thinking_start_time, #记录了思考开始的时间
|
||||
reply_message_id=message.message_id
|
||||
chat_stream=chat,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=not mark_head,
|
||||
is_emoji=False
|
||||
)
|
||||
await bot_message.initialize()
|
||||
if not mark_head:
|
||||
bot_message.is_head = True
|
||||
mark_head = True
|
||||
message_set.add_message(bot_message)
|
||||
|
||||
@@ -180,30 +202,22 @@ class ChatBot:
|
||||
if emoji_raw != None:
|
||||
emoji_path,discription = emoji_raw
|
||||
|
||||
emoji_cq = cq_code_tool.create_emoji_cq(emoji_path)
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
|
||||
if random() < 0.5:
|
||||
bot_response_time = tinking_time_point - 1
|
||||
else:
|
||||
bot_response_time = bot_response_time + 1
|
||||
|
||||
bot_message = Message_Sending(
|
||||
group_id=event.group_id,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_id=0,
|
||||
raw_message=emoji_cq,
|
||||
plain_text=emoji_cq,
|
||||
processed_plain_text=emoji_cq,
|
||||
detailed_plain_text=discription,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
group_name=message.group_name,
|
||||
time=bot_response_time,
|
||||
is_emoji=True,
|
||||
translate_cq=False,
|
||||
thinking_start_time=thinking_start_time,
|
||||
# reply_message_id=message.message_id
|
||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
chat_stream=chat,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=False,
|
||||
is_emoji=True
|
||||
)
|
||||
await bot_message.initialize()
|
||||
message_manager.add_message(bot_message)
|
||||
emotion = await self.gpt._get_emotion_tags(raw_content)
|
||||
print(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
||||
@@ -219,8 +233,12 @@ class ChatBot:
|
||||
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
||||
# 使用情绪管理器更新情绪
|
||||
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
||||
|
||||
# willing_manager.change_reply_willing_after_sent(event.group_id)
|
||||
|
||||
willing_manager.change_reply_willing_after_sent(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo
|
||||
)
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
chat_bot = ChatBot()
|
||||
209
src/plugins/chat/chat_stream.py
Normal file
209
src/plugins/chat/chat_stream.py
Normal file
@@ -0,0 +1,209 @@
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Optional, Dict, Tuple
|
||||
import hashlib
|
||||
|
||||
from loguru import logger
|
||||
from ...common.database import Database
|
||||
from .message_base import UserInfo, GroupInfo
|
||||
|
||||
|
||||
class ChatStream:
|
||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||
def __init__(self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: dict = None):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
self.user_info = user_info
|
||||
self.group_info = group_info
|
||||
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
|
||||
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
|
||||
self.saved = False
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
result = {
|
||||
'stream_id': self.stream_id,
|
||||
'platform': self.platform,
|
||||
'user_info': self.user_info.to_dict() if self.user_info else None,
|
||||
'group_info': self.group_info.to_dict() if self.group_info else None,
|
||||
'create_time': self.create_time,
|
||||
'last_active_time': self.last_active_time
|
||||
}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> 'ChatStream':
|
||||
"""从字典创建实例"""
|
||||
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
|
||||
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
|
||||
|
||||
return cls(
|
||||
stream_id=data['stream_id'],
|
||||
platform=data['platform'],
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data
|
||||
)
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = int(time.time())
|
||||
self.saved = False
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,管理所有聊天流"""
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.db = Database.get_instance()
|
||||
self._ensure_collection()
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
asyncio.create_task(self._initialize())
|
||||
# 启动自动保存任务
|
||||
asyncio.create_task(self._auto_save_task())
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化"""
|
||||
try:
|
||||
await self.load_all_streams()
|
||||
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
while True:
|
||||
await asyncio.sleep(300) # 每5分钟保存一次
|
||||
try:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
def _ensure_collection(self):
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if 'chat_streams' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('chat_streams')
|
||||
# 创建索引
|
||||
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
|
||||
self.db.db.chat_streams.create_index([
|
||||
('platform', 1),
|
||||
('user_info.user_id', 1),
|
||||
('group_info.group_id', 1)
|
||||
])
|
||||
|
||||
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
# 组合关键信息
|
||||
components = [
|
||||
platform,
|
||||
str(user_info.user_id),
|
||||
str(group_info.group_id) if group_info else 'private'
|
||||
]
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = '_'.join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(self,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
user_info: 用户信息
|
||||
group_info: 群组信息(可选)
|
||||
|
||||
Returns:
|
||||
ChatStream: 聊天流对象
|
||||
"""
|
||||
# 生成stream_id
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
|
||||
if data:
|
||||
stream = ChatStream.from_dict(data)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
else:
|
||||
# 创建新的聊天流
|
||||
stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info
|
||||
)
|
||||
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
"""通过stream_id获取聊天流"""
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_by_info(self,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
async def _save_stream(self, stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
self.db.db.chat_streams.update_one(
|
||||
{'stream_id': stream.stream_id},
|
||||
{'$set': stream.to_dict()},
|
||||
upsert=True
|
||||
)
|
||||
stream.saved = True
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
for stream in self.streams.values():
|
||||
await self._save_stream(stream)
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
all_streams = self.db.db.chat_streams.find({})
|
||||
for data in all_streams:
|
||||
stream = ChatStream.from_dict(data)
|
||||
self.streams[stream.stream_id] = stream
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
chat_manager = ChatManager()
|
||||
@@ -373,6 +373,24 @@ class CQCode_tool:
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||
|
||||
@staticmethod
|
||||
def create_emoji_cq_base64(base64_data: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
base64_data: base64编码的表情包数据
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 转义base64数据
|
||||
escaped_base64 = base64_data.replace('&', '&') \
|
||||
.replace('[', '[') \
|
||||
.replace(']', ']') \
|
||||
.replace(',', ',')
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
import base64
|
||||
import hashlib
|
||||
|
||||
@@ -92,7 +92,7 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
|
||||
"""根据文本内容获取相关表情包
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
@@ -5,11 +5,10 @@ from typing import Dict, ForwardRef, List, Optional, Union
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
|
||||
from .cq_code import CQCode, cq_code_tool
|
||||
from .utils_cq import parse_cq_code
|
||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
||||
from .utils_image import image_manager
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
from .chat_stream import ChatStream
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
@@ -115,36 +114,17 @@ class MessageProcessBase(MessageBase):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
platform: str = "qq",
|
||||
chat_stream: ChatStream,
|
||||
message_segment: Optional[Seg] = None,
|
||||
reply: Optional['MessageRecv'] = None
|
||||
):
|
||||
# 构造用户信息
|
||||
user_info = UserInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=get_user_nickname(user_id),
|
||||
user_cardname=get_user_cardname(user_id) if group_id else None
|
||||
)
|
||||
|
||||
# 构造群组信息(如果有)
|
||||
group_info = None
|
||||
if group_id:
|
||||
group_info = GroupInfo(
|
||||
platform=platform,
|
||||
group_id=group_id,
|
||||
group_name=get_groupname(group_id)
|
||||
)
|
||||
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=platform,
|
||||
platform=chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=int(time.time()),
|
||||
group_info=group_info,
|
||||
user_info=user_info
|
||||
group_info=chat_stream.group_info,
|
||||
user_info=chat_stream.user_info
|
||||
)
|
||||
|
||||
# 调用父类初始化
|
||||
@@ -241,17 +221,13 @@ class MessageThinking(MessageProcessBase):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
platform: str = "qq",
|
||||
chat_stream: ChatStream,
|
||||
reply: Optional['MessageRecv'] = None
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
platform=platform,
|
||||
chat_stream=chat_stream,
|
||||
message_segment=None, # 思考状态不需要消息段
|
||||
reply=reply
|
||||
)
|
||||
@@ -259,6 +235,15 @@ class MessageThinking(MessageProcessBase):
|
||||
# 思考状态特有属性
|
||||
self.interrupt = False
|
||||
|
||||
@classmethod
|
||||
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
|
||||
"""从聊天流创建思考状态消息"""
|
||||
return cls(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
reply=reply
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MessageSending(MessageProcessBase):
|
||||
"""发送状态的消息类"""
|
||||
@@ -266,19 +251,16 @@ class MessageSending(MessageProcessBase):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
user_id: int,
|
||||
chat_stream: ChatStream,
|
||||
message_segment: Seg,
|
||||
group_id: Optional[int] = None,
|
||||
reply: Optional['MessageRecv'] = None,
|
||||
platform: str = "qq",
|
||||
is_head: bool = False
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
platform=platform,
|
||||
chat_stream=chat_stream,
|
||||
message_segment=message_segment,
|
||||
reply=reply
|
||||
)
|
||||
@@ -286,6 +268,12 @@ class MessageSending(MessageProcessBase):
|
||||
# 发送状态特有属性
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
if is_head:
|
||||
self.message_segment = Seg(type='seglist', data=[
|
||||
Seg(type='reply', data=reply.message_info.message_id),
|
||||
self.message_segment
|
||||
])
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
@@ -298,26 +286,24 @@ class MessageSending(MessageProcessBase):
|
||||
cls,
|
||||
thinking: MessageThinking,
|
||||
message_segment: Seg,
|
||||
reply: Optional['MessageRecv'] = None,
|
||||
is_head: bool = False
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False
|
||||
) -> 'MessageSending':
|
||||
"""从思考状态消息创建发送状态消息"""
|
||||
return cls(
|
||||
message_id=thinking.message_info.message_id,
|
||||
user_id=thinking.message_info.user_info.user_id,
|
||||
chat_stream=thinking.chat_stream,
|
||||
message_segment=message_segment,
|
||||
group_id=thinking.message_info.group_info.group_id if thinking.message_info.group_info else None,
|
||||
reply=reply or thinking.reply,
|
||||
platform=thinking.message_info.platform,
|
||||
is_head=is_head
|
||||
reply=thinking.reply,
|
||||
is_head=is_head,
|
||||
is_emoji=is_emoji
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
def __init__(self, group_id: int, user_id: int, message_id: str):
|
||||
self.group_id = group_id
|
||||
self.user_id = user_id
|
||||
def __init__(self, chat_stream: ChatStream, message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: List[MessageSending] = []
|
||||
self.time = round(time.time(), 2)
|
||||
|
||||
@@ -176,7 +176,7 @@ class MessageSendCQ(MessageCQ):
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if seg.data.startswith(('data:', 'base64:')):
|
||||
return f"[CQ:image,file=base64://{seg.data}]"
|
||||
return cq_code_tool.create_emoji_cq_base64(seg.data)
|
||||
# 如果是表情包(本地文件)
|
||||
return cq_code_tool.create_emoji_cq(seg.data)
|
||||
elif seg.type == 'at':
|
||||
|
||||
@@ -5,10 +5,11 @@ from typing import Dict, List, Optional, Union
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
|
||||
from .cq_code import cq_code_tool
|
||||
from .message_cq import Message, Message_Sending, Message_Thinking, MessageSet
|
||||
from .message_cq import MessageSendCQ
|
||||
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time
|
||||
from .config import global_config
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
|
||||
class Message_Sender:
|
||||
@@ -21,66 +22,59 @@ class Message_Sender:
|
||||
def set_bot(self, bot: Bot):
|
||||
"""设置当前bot实例"""
|
||||
self._current_bot = bot
|
||||
|
||||
async def send_group_message(
|
||||
self,
|
||||
group_id: int,
|
||||
send_text: str,
|
||||
auto_escape: bool = False,
|
||||
reply_message_id: int = None,
|
||||
at_user_id: int = None
|
||||
) -> None:
|
||||
|
||||
if not self._current_bot:
|
||||
raise RuntimeError("Bot未设置,请先调用set_bot方法设置bot实例")
|
||||
|
||||
message = send_text
|
||||
|
||||
# 如果需要回复
|
||||
if reply_message_id:
|
||||
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
|
||||
message = reply_cq + message
|
||||
|
||||
# 如果需要at
|
||||
# if at_user_id:
|
||||
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
||||
# message = at_cq + " " + message
|
||||
|
||||
|
||||
typing_time = calculate_typing_time(message)
|
||||
if typing_time > 10:
|
||||
typing_time = 10
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
# 发送消息
|
||||
try:
|
||||
await self._current_bot.send_group_msg(
|
||||
group_id=group_id,
|
||||
message=message,
|
||||
auto_escape=auto_escape
|
||||
async def send_message(
|
||||
self,
|
||||
message: MessageSending,
|
||||
) -> None:
|
||||
"""发送消息"""
|
||||
if isinstance(message, MessageSending):
|
||||
message_send=MessageSendCQ(
|
||||
message_id=message.message_id,
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
message_segment=message.message_segment,
|
||||
reply=message.reply
|
||||
)
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||
except Exception as e:
|
||||
print(f"发生错误 {e}")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||
if message.message_info.group_info:
|
||||
try:
|
||||
await self._current_bot.send_group_msg(
|
||||
group_id=message.message_info.group_info.group_id,
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||
except Exception as e:
|
||||
print(f"发生错误 {e}")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||
else:
|
||||
try:
|
||||
await self._current_bot.send_private_msg(
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||
except Exception as e:
|
||||
print(f"发生错误 {e}")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||
|
||||
|
||||
class MessageContainer:
|
||||
"""单个群的发送/思考消息容器"""
|
||||
def __init__(self, group_id: int, max_size: int = 100):
|
||||
self.group_id = group_id
|
||||
"""单个聊天流的发送/思考消息容器"""
|
||||
def __init__(self, chat_id: str, max_size: int = 100):
|
||||
self.chat_id = chat_id
|
||||
self.max_size = max_size
|
||||
self.messages = []
|
||||
self.last_send_time = 0
|
||||
self.thinking_timeout = 20 # 思考超时时间(秒)
|
||||
|
||||
def get_timeout_messages(self) -> List[Message_Sending]:
|
||||
def get_timeout_messages(self) -> List[MessageSending]:
|
||||
"""获取所有超时的Message_Sending对象(思考时间超过30秒),按thinking_start_time排序"""
|
||||
current_time = time.time()
|
||||
timeout_messages = []
|
||||
|
||||
for msg in self.messages:
|
||||
if isinstance(msg, Message_Sending):
|
||||
if isinstance(msg, MessageSending):
|
||||
if current_time - msg.thinking_start_time > self.thinking_timeout:
|
||||
timeout_messages.append(msg)
|
||||
|
||||
@@ -89,7 +83,7 @@ class MessageContainer:
|
||||
|
||||
return timeout_messages
|
||||
|
||||
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
|
||||
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
|
||||
"""获取thinking_start_time最早的消息对象"""
|
||||
if not self.messages:
|
||||
return None
|
||||
@@ -102,16 +96,15 @@ class MessageContainer:
|
||||
earliest_message = msg
|
||||
return earliest_message
|
||||
|
||||
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
|
||||
"""添加消息到队列"""
|
||||
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
|
||||
if isinstance(message, MessageSet):
|
||||
for single_message in message.messages:
|
||||
self.messages.append(single_message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
|
||||
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
|
||||
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
|
||||
"""移除消息,如果消息存在则返回True,否则返回False"""
|
||||
try:
|
||||
if message in self.messages:
|
||||
@@ -126,40 +119,42 @@ class MessageContainer:
|
||||
"""检查是否有待发送的消息"""
|
||||
return bool(self.messages)
|
||||
|
||||
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
|
||||
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
|
||||
"""获取所有消息"""
|
||||
return list(self.messages)
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""管理所有群的消息容器"""
|
||||
"""管理所有聊天流的消息容器"""
|
||||
def __init__(self):
|
||||
self.containers: Dict[int, MessageContainer] = {}
|
||||
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
|
||||
self.storage = MessageStorage()
|
||||
self._running = True
|
||||
|
||||
def get_container(self, group_id: int) -> MessageContainer:
|
||||
"""获取或创建群的消息容器"""
|
||||
if group_id not in self.containers:
|
||||
self.containers[group_id] = MessageContainer(group_id)
|
||||
return self.containers[group_id]
|
||||
def get_container(self, chat_id: str) -> MessageContainer:
|
||||
"""获取或创建聊天流的消息容器"""
|
||||
if chat_id not in self.containers:
|
||||
self.containers[chat_id] = MessageContainer(chat_id)
|
||||
return self.containers[chat_id]
|
||||
|
||||
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
|
||||
container = self.get_container(message.group_id)
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
|
||||
chat_stream = chat_manager.get_stream_by_info(
|
||||
platform=message.message_info.platform,
|
||||
user_info=message.message_info.user_info,
|
||||
group_info=message.message_info.group_info
|
||||
)
|
||||
if not chat_stream:
|
||||
raise ValueError("无法找到对应的聊天流")
|
||||
container = self.get_container(chat_stream.stream_id)
|
||||
container.add_message(message)
|
||||
|
||||
async def process_group_messages(self, group_id: int):
|
||||
"""处理群消息"""
|
||||
# if int(time.time() / 3) == time.time() / 3:
|
||||
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
||||
container = self.get_container(group_id)
|
||||
async def process_chat_messages(self, chat_id: str):
|
||||
"""处理聊天流消息"""
|
||||
container = self.get_container(chat_id)
|
||||
if container.has_messages():
|
||||
#最早的对象,可能是思考消息,也可能是发送消息
|
||||
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
|
||||
message_earliest = container.get_earliest_message()
|
||||
|
||||
#如果是思考消息
|
||||
if isinstance(message_earliest, Message_Thinking):
|
||||
#优先等待这条消息
|
||||
if isinstance(message_earliest, MessageThinking):
|
||||
message_earliest.update_thinking_time()
|
||||
thinking_time = message_earliest.thinking_time
|
||||
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒\033[K\r", end='', flush=True)
|
||||
@@ -168,42 +163,36 @@ class MessageManager:
|
||||
if thinking_time > global_config.thinking_timeout:
|
||||
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
|
||||
container.remove_message(message_earliest)
|
||||
else:# 如果不是message_thinking就只能是message_sending
|
||||
else:
|
||||
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||
#直接发,等什么呢
|
||||
if message_earliest.is_head and message_earliest.update_thinking_time() >30:
|
||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
|
||||
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
||||
await message_sender.send_message(message_earliest)
|
||||
else:
|
||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
|
||||
#移除消息
|
||||
await message_sender.send_message(message_earliest)
|
||||
|
||||
if message_earliest.is_emoji:
|
||||
message_earliest.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(message_earliest, None)
|
||||
|
||||
container.remove_message(message_earliest)
|
||||
|
||||
#获取并处理超时消息
|
||||
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
|
||||
message_timeout = container.get_timeout_messages()
|
||||
if message_timeout:
|
||||
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
|
||||
for msg in message_timeout:
|
||||
if msg == message_earliest:
|
||||
continue # 跳过已经处理过的消息
|
||||
continue
|
||||
|
||||
try:
|
||||
#发送
|
||||
if msg.is_head and msg.update_thinking_time() >30:
|
||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
||||
if msg.is_head and msg.update_thinking_time() > 30:
|
||||
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
||||
else:
|
||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
|
||||
|
||||
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False)
|
||||
|
||||
#如果是表情包,则替换为"[表情包]"
|
||||
if msg.is_emoji:
|
||||
msg.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(msg, None)
|
||||
|
||||
# 安全地移除消息
|
||||
if not container.remove_message(msg):
|
||||
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
||||
except Exception as e:
|
||||
@@ -215,8 +204,8 @@ class MessageManager:
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
tasks = []
|
||||
for group_id in self.containers.keys():
|
||||
tasks.append(self.process_group_messages(group_id))
|
||||
for chat_id in self.containers.keys():
|
||||
tasks.append(self.process_chat_messages(chat_id))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import UserInfo
|
||||
|
||||
|
||||
class Impression:
|
||||
@@ -13,27 +14,36 @@ class Impression:
|
||||
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
# impression: Impression = None
|
||||
# group_id: int = None
|
||||
# group_name: str = None
|
||||
platform: str = None
|
||||
gender: str = None
|
||||
age: int = None
|
||||
nickname: str = None
|
||||
relationship_value: float = None
|
||||
saved = False
|
||||
|
||||
def __init__(self, user_id: int, data=None, **kwargs):
|
||||
def __init__(self, user_id: int = None, data: dict = None, user_info: UserInfo = None, **kwargs):
|
||||
if isinstance(data, dict):
|
||||
# 如果输入是字典,使用字典解析
|
||||
self.user_id = data.get('user_id')
|
||||
self.platform = data.get('platform', 'qq')
|
||||
self.gender = data.get('gender')
|
||||
self.age = data.get('age')
|
||||
self.nickname = data.get('nickname')
|
||||
self.relationship_value = data.get('relationship_value', 0.0)
|
||||
self.saved = data.get('saved', False)
|
||||
elif user_info is not None:
|
||||
# 如果输入是UserInfo对象
|
||||
self.user_id = user_info.user_id
|
||||
self.platform = user_info.platform or 'qq'
|
||||
self.nickname = user_info.user_nickname or user_info.user_cardname or "某人"
|
||||
self.relationship_value = kwargs.get('relationship_value', 0.0)
|
||||
self.gender = kwargs.get('gender')
|
||||
self.age = kwargs.get('age')
|
||||
self.saved = kwargs.get('saved', False)
|
||||
else:
|
||||
# 如果是直接传入属性值
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.platform = kwargs.get('platform', 'qq')
|
||||
self.gender = kwargs.get('gender')
|
||||
self.age = kwargs.get('age')
|
||||
self.nickname = kwargs.get('nickname')
|
||||
@@ -41,32 +51,63 @@ class Relationship:
|
||||
self.saved = kwargs.get('saved', False)
|
||||
|
||||
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationships: dict[int, Relationship] = {}
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
|
||||
async def update_relationship(self, user_id: int, data=None, **kwargs):
|
||||
async def update_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None,
|
||||
data: dict = None,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新或创建关系
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
data: 字典格式的数据(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
# 如果存在,更新现有对象
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
for k, value in kwargs.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
# 如果不存在,创建新对象
|
||||
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs)
|
||||
self.relationships[user_id] = relationship
|
||||
|
||||
# 更新 id_name_nickname_table
|
||||
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
|
||||
if user_info is not None:
|
||||
relationship = Relationship(user_info=user_info, **kwargs)
|
||||
elif isinstance(data, dict):
|
||||
data['platform'] = platform
|
||||
relationship = Relationship(user_id=user_id, data=data)
|
||||
else:
|
||||
kwargs['platform'] = platform
|
||||
kwargs['user_id'] = user_id
|
||||
relationship = Relationship(**kwargs)
|
||||
self.relationships[key] = relationship
|
||||
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
@@ -74,33 +115,87 @@ class RelationshipManager:
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self, user_id: int, **kwargs):
|
||||
async def update_relationship_value(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新关系值
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
for key, value in kwargs.items():
|
||||
if key == 'relationship_value':
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
relationship.relationship_value += value
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
return relationship
|
||||
else:
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新")
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
return None
|
||||
|
||||
|
||||
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
||||
"""获取用户关系对象"""
|
||||
if user_id in self.relationships:
|
||||
return self.relationships[user_id]
|
||||
def get_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> Optional[Relationship]:
|
||||
"""获取用户关系对象
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key]
|
||||
else:
|
||||
return 0
|
||||
|
||||
async def load_relationship(self, data: dict) -> Relationship:
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
rela = Relationship(user_id=data['user_id'], data=data)
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
rela.saved = True
|
||||
self.relationships[rela.user_id] = rela
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
return rela
|
||||
|
||||
async def load_all_relationships(self):
|
||||
@@ -117,9 +212,7 @@ class RelationshipManager:
|
||||
all_relationships = db.db.relationships.find({})
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
user_id = data['user_id']
|
||||
relationship = await self.load_relationship(data)
|
||||
self.relationships[user_id] = relationship
|
||||
await self.load_relationship(data)
|
||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
@@ -130,16 +223,15 @@ class RelationshipManager:
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for userid, relationship in self.relationships.items():
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
if not relationship.saved:
|
||||
relationship.saved = True
|
||||
await self.storage_relationship(relationship)
|
||||
|
||||
async def storage_relationship(self,relationship: Relationship):
|
||||
"""
|
||||
将关系记录存储到数据库中
|
||||
"""
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""将关系记录存储到数据库中"""
|
||||
user_id = relationship.user_id
|
||||
platform = relationship.platform
|
||||
nickname = relationship.nickname
|
||||
relationship_value = relationship.relationship_value
|
||||
gender = relationship.gender
|
||||
@@ -148,8 +240,9 @@ class RelationshipManager:
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id},
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'$set': {
|
||||
'platform': platform,
|
||||
'nickname': nickname,
|
||||
'relationship_value': relationship_value,
|
||||
'gender': gender,
|
||||
@@ -159,12 +252,35 @@ class RelationshipManager:
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def get_name(self, user_id: int) -> str:
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> str:
|
||||
"""获取用户昵称
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
str: 用户昵称
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 确保user_id是整数类型
|
||||
user_id = int(user_id)
|
||||
if user_id in self.relationships:
|
||||
|
||||
return self.relationships[user_id].nickname
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||
else:
|
||||
return "某人"
|
||||
|
||||
|
||||
@@ -1,47 +1,28 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_cq import Message
|
||||
from .message_base import MessageBase
|
||||
from .message import MessageSending, MessageRecv
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
|
||||
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
|
||||
async def store_message(self, message: Union[MessageSending, MessageRecv], topic: Optional[str] = None) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
if not message.is_emoji:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
message_data = {
|
||||
"message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_nickname": message.message_info.user_info.user_nickname,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"user_cardname": message.user_cardname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
}
|
||||
else:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
"processed_plain_text": '[表情包]',
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"user_cardname": message.user_cardname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
}
|
||||
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from .config import global_config
|
||||
from .message_base import UserInfo, GroupInfo
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
|
||||
class WillingManager:
|
||||
def __init__(self):
|
||||
self.group_reply_willing = {} # 存储每个群的回复意愿
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self._decay_task = None
|
||||
self._started = False
|
||||
|
||||
@@ -12,20 +17,35 @@ class WillingManager:
|
||||
"""定期衰减回复意愿"""
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
for group_id in self.group_reply_willing:
|
||||
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
|
||||
def get_willing(self, group_id: int) -> float:
|
||||
"""获取指定群组的回复意愿"""
|
||||
return self.group_reply_willing.get(group_id, 0)
|
||||
def get_willing(self, platform: str, user_info: UserInfo, group_info: GroupInfo = None) -> float:
|
||||
"""获取指定聊天流的回复意愿"""
|
||||
stream = chat_manager.get_stream_by_info(platform, user_info, group_info)
|
||||
if stream:
|
||||
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
return 0
|
||||
|
||||
def set_willing(self, group_id: int, willing: float):
|
||||
"""设置指定群组的回复意愿"""
|
||||
self.group_reply_willing[group_id] = willing
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
|
||||
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
|
||||
"""改变指定群组的回复意愿并返回回复概率"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
async def change_reply_willing_received(self,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: GroupInfo = None,
|
||||
topic: str = None,
|
||||
is_mentioned_bot: bool = False,
|
||||
config = None,
|
||||
is_emoji: bool = False,
|
||||
interested_rate: float = 0) -> float:
|
||||
"""改变指定聊天流的回复意愿并返回回复概率"""
|
||||
# 获取或创建聊天流
|
||||
stream = await chat_manager.get_or_create_stream(platform, user_info, group_info)
|
||||
chat_id = stream.stream_id
|
||||
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
|
||||
# print(f"初始意愿: {current_willing}")
|
||||
if is_mentioned_bot and current_willing < 1.0:
|
||||
@@ -49,31 +69,37 @@ class WillingManager:
|
||||
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
||||
|
||||
reply_probability = max((current_willing - 0.45) * 2, 0)
|
||||
if group_id not in config.talk_allowed_groups:
|
||||
current_willing = 0
|
||||
reply_probability = 0
|
||||
|
||||
if group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / global_config.down_frequency_rate
|
||||
|
||||
# 检查群组权限(如果是群聊)
|
||||
if group_info:
|
||||
if group_info.group_id not in config.talk_allowed_groups:
|
||||
current_willing = 0
|
||||
reply_probability = 0
|
||||
|
||||
if group_info.group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / global_config.down_frequency_rate
|
||||
|
||||
reply_probability = min(reply_probability, 1)
|
||||
if reply_probability < 0:
|
||||
reply_probability = 0
|
||||
|
||||
|
||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
return reply_probability
|
||||
|
||||
def change_reply_willing_sent(self, group_id: int):
|
||||
"""开始思考后降低群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
self.group_reply_willing[group_id] = max(0, current_willing - 2)
|
||||
def change_reply_willing_sent(self, platform: str, user_info: UserInfo, group_info: GroupInfo = None):
|
||||
"""开始思考后降低聊天流的回复意愿"""
|
||||
stream = chat_manager.get_stream_by_info(platform, user_info, group_info)
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
|
||||
|
||||
def change_reply_willing_after_sent(self, group_id: int):
|
||||
"""发送消息后提高群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
if current_willing < 1:
|
||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
|
||||
def change_reply_willing_after_sent(self, platform: str, user_info: UserInfo, group_info: GroupInfo = None):
|
||||
"""发送消息后提高聊天流的回复意愿"""
|
||||
stream = chat_manager.get_stream_by_info(platform, user_info, group_info)
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
if current_willing < 1:
|
||||
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
|
||||
|
||||
async def ensure_started(self):
|
||||
"""确保衰减任务已启动"""
|
||||
|
||||
Reference in New Issue
Block a user