refractor: 几乎写完了,进入测试阶段
This commit is contained in:
@@ -10,18 +10,19 @@ from .config import global_config
|
|||||||
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from .message_cq import (
|
from .message_cq import (
|
||||||
Message,
|
MessageRecvCQ,
|
||||||
Message_Sending,
|
MessageSendCQ,
|
||||||
Message_Thinking, # 导入 Message_Thinking 类
|
|
||||||
MessageSet,
|
|
||||||
)
|
)
|
||||||
|
from .chat_stream import chat_manager
|
||||||
from .message_sender import message_manager # 导入新的消息管理器
|
from .message_sender import message_manager # 导入新的消息管理器
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||||
|
from .utils_image import image_path_to_base64
|
||||||
from .willing_manager import willing_manager # 导入意愿管理器
|
from .willing_manager import willing_manager # 导入意愿管理器
|
||||||
|
from .message_base import UserInfo, GroupInfo, Seg
|
||||||
|
|
||||||
class ChatBot:
|
class ChatBot:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -43,12 +44,9 @@ class ChatBot:
|
|||||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||||
"""处理收到的群消息"""
|
"""处理收到的群消息"""
|
||||||
|
|
||||||
if event.group_id not in global_config.talk_allowed_groups:
|
|
||||||
return
|
|
||||||
self.bot = bot # 更新 bot 实例
|
self.bot = bot # 更新 bot 实例
|
||||||
|
|
||||||
if event.user_id in global_config.ban_user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||||
@@ -56,25 +54,46 @@ class ChatBot:
|
|||||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||||
|
|
||||||
message = Message(
|
message_cq=MessageRecvCQ(
|
||||||
group_id=event.group_id,
|
|
||||||
user_id=event.user_id,
|
|
||||||
message_id=event.message_id,
|
message_id=event.message_id,
|
||||||
user_cardname=sender_info['card'],
|
user_id=event.user_id,
|
||||||
raw_message=str(event.original_message),
|
raw_message=str(event.original_message),
|
||||||
plain_text=event.get_plaintext(),
|
group_id=event.group_id,
|
||||||
reply_message=event.reply,
|
reply_message=event.reply,
|
||||||
|
platform='qq'
|
||||||
)
|
)
|
||||||
await message.initialize()
|
message_json=message_cq.to_dict()
|
||||||
|
|
||||||
|
# 进入maimbot
|
||||||
|
message=MessageRecv(**message_json)
|
||||||
|
await message.process()
|
||||||
|
groupinfo=message.message_info.group_info
|
||||||
|
userinfo=message.message_info.user_info
|
||||||
|
messageinfo=message.message_info
|
||||||
|
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
||||||
|
|
||||||
|
# 消息过滤,涉及到config有待更新
|
||||||
|
if groupinfo:
|
||||||
|
if groupinfo.group_id not in global_config.talk_allowed_groups:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if userinfo:
|
||||||
|
if userinfo.user_id in []:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
if userinfo.user_id in global_config.ban_user_id:
|
||||||
|
return
|
||||||
# 过滤词
|
# 过滤词
|
||||||
for word in global_config.ban_words:
|
for word in global_config.ban_words:
|
||||||
if word in message.detailed_plain_text:
|
if word in message.processed_plain_text:
|
||||||
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
|
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||||
return
|
return
|
||||||
|
|
||||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -85,47 +104,55 @@ class ChatBot:
|
|||||||
print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
|
print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
|
||||||
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||||
|
|
||||||
await self.storage.store_message(message, topic[0] if topic else None)
|
await self.storage.store_message(message,chat, topic[0] if topic else None)
|
||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||||
reply_probability = willing_manager.change_reply_willing_received(
|
reply_probability = await willing_manager.change_reply_willing_received(
|
||||||
event.group_id,
|
chat_stream=chat,
|
||||||
topic[0] if topic else None,
|
topic=topic[0] if topic else None,
|
||||||
is_mentioned,
|
is_mentioned_bot=is_mentioned,
|
||||||
global_config,
|
config=global_config,
|
||||||
event.user_id,
|
is_emoji=message.is_emoji,
|
||||||
message.is_emoji,
|
interested_rate=interested_rate
|
||||||
interested_rate
|
)
|
||||||
|
current_willing = willing_manager.get_willing(
|
||||||
|
chat_stream=chat
|
||||||
)
|
)
|
||||||
current_willing = willing_manager.get_willing(event.group_id)
|
|
||||||
|
|
||||||
|
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
||||||
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
|
|
||||||
|
|
||||||
response = ""
|
response = None
|
||||||
|
|
||||||
if random() < reply_probability:
|
if random() < reply_probability:
|
||||||
|
bot_user_info=UserInfo(
|
||||||
|
user_id=global_config.BOT_QQ,
|
||||||
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
|
platform=messageinfo.platform
|
||||||
|
)
|
||||||
tinking_time_point = round(time.time(), 2)
|
tinking_time_point = round(time.time(), 2)
|
||||||
think_id = 'mt' + str(tinking_time_point)
|
think_id = 'mt' + str(tinking_time_point)
|
||||||
thinking_message = Message_Thinking(message=message,message_id=think_id)
|
thinking_message = MessageThinking.from_chat_stream(
|
||||||
|
chat_stream=chat,
|
||||||
|
message_id=think_id,
|
||||||
|
reply=message
|
||||||
|
)
|
||||||
|
|
||||||
message_manager.add_message(thinking_message)
|
message_manager.add_message(thinking_message)
|
||||||
|
|
||||||
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
willing_manager.change_reply_willing_sent(
|
||||||
|
chat_stream=chat
|
||||||
|
)
|
||||||
|
|
||||||
response,raw_content = await self.gpt.generate_response(message)
|
response,raw_content = await self.gpt.generate_response(message)
|
||||||
|
|
||||||
if response:
|
if response:
|
||||||
container = message_manager.get_container(event.group_id)
|
container = message_manager.get_container(chat.stream_id)
|
||||||
thinking_message = None
|
thinking_message = None
|
||||||
# 找到message,删除
|
# 找到message,删除
|
||||||
for msg in container.messages:
|
for msg in container.messages:
|
||||||
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
|
if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
|
||||||
thinking_message = msg
|
thinking_message = msg
|
||||||
container.messages.remove(msg)
|
container.messages.remove(msg)
|
||||||
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# 如果找不到思考消息,直接返回
|
# 如果找不到思考消息,直接返回
|
||||||
@@ -135,11 +162,10 @@ class ChatBot:
|
|||||||
|
|
||||||
#记录开始思考的时间,避免从思考到回复的时间太久
|
#记录开始思考的时间,避免从思考到回复的时间太久
|
||||||
thinking_start_time = thinking_message.thinking_start_time
|
thinking_start_time = thinking_message.thinking_start_time
|
||||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
message_set = MessageSet(chat, think_id)
|
||||||
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||||
accu_typing_time = 0
|
accu_typing_time = 0
|
||||||
|
|
||||||
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
|
|
||||||
mark_head = False
|
mark_head = False
|
||||||
for msg in response:
|
for msg in response:
|
||||||
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
|
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
|
||||||
@@ -148,22 +174,16 @@ class ChatBot:
|
|||||||
accu_typing_time += typing_time
|
accu_typing_time += typing_time
|
||||||
timepoint = tinking_time_point + accu_typing_time
|
timepoint = tinking_time_point + accu_typing_time
|
||||||
|
|
||||||
bot_message = Message_Sending(
|
message_segment = Seg(type='text', data=msg)
|
||||||
group_id=event.group_id,
|
bot_message = MessageSending(
|
||||||
user_id=global_config.BOT_QQ,
|
|
||||||
message_id=think_id,
|
message_id=think_id,
|
||||||
raw_message=msg,
|
chat_stream=chat,
|
||||||
plain_text=msg,
|
message_segment=message_segment,
|
||||||
processed_plain_text=msg,
|
reply=message,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
is_head=not mark_head,
|
||||||
group_name=message.group_name,
|
is_emoji=False
|
||||||
time=timepoint, #记录了回复生成的时间
|
|
||||||
thinking_start_time=thinking_start_time, #记录了思考开始的时间
|
|
||||||
reply_message_id=message.message_id
|
|
||||||
)
|
)
|
||||||
await bot_message.initialize()
|
|
||||||
if not mark_head:
|
if not mark_head:
|
||||||
bot_message.is_head = True
|
|
||||||
mark_head = True
|
mark_head = True
|
||||||
message_set.add_message(bot_message)
|
message_set.add_message(bot_message)
|
||||||
|
|
||||||
@@ -180,30 +200,22 @@ class ChatBot:
|
|||||||
if emoji_raw != None:
|
if emoji_raw != None:
|
||||||
emoji_path,discription = emoji_raw
|
emoji_path,discription = emoji_raw
|
||||||
|
|
||||||
emoji_cq = cq_code_tool.create_emoji_cq(emoji_path)
|
emoji_cq = image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
if random() < 0.5:
|
if random() < 0.5:
|
||||||
bot_response_time = tinking_time_point - 1
|
bot_response_time = tinking_time_point - 1
|
||||||
else:
|
else:
|
||||||
bot_response_time = bot_response_time + 1
|
bot_response_time = bot_response_time + 1
|
||||||
|
|
||||||
bot_message = Message_Sending(
|
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||||
group_id=event.group_id,
|
bot_message = MessageSending(
|
||||||
user_id=global_config.BOT_QQ,
|
message_id=think_id,
|
||||||
message_id=0,
|
chat_stream=chat,
|
||||||
raw_message=emoji_cq,
|
message_segment=message_segment,
|
||||||
plain_text=emoji_cq,
|
reply=message,
|
||||||
processed_plain_text=emoji_cq,
|
is_head=False,
|
||||||
detailed_plain_text=discription,
|
is_emoji=True
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
|
||||||
group_name=message.group_name,
|
|
||||||
time=bot_response_time,
|
|
||||||
is_emoji=True,
|
|
||||||
translate_cq=False,
|
|
||||||
thinking_start_time=thinking_start_time,
|
|
||||||
# reply_message_id=message.message_id
|
|
||||||
)
|
)
|
||||||
await bot_message.initialize()
|
|
||||||
message_manager.add_message(bot_message)
|
message_manager.add_message(bot_message)
|
||||||
emotion = await self.gpt._get_emotion_tags(raw_content)
|
emotion = await self.gpt._get_emotion_tags(raw_content)
|
||||||
print(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
print(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
||||||
@@ -219,8 +231,12 @@ class ChatBot:
|
|||||||
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
|
||||||
# 使用情绪管理器更新情绪
|
# 使用情绪管理器更新情绪
|
||||||
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
||||||
|
|
||||||
# willing_manager.change_reply_willing_after_sent(event.group_id)
|
willing_manager.change_reply_willing_after_sent(
|
||||||
|
platform=messageinfo.platform,
|
||||||
|
user_info=userinfo,
|
||||||
|
group_info=groupinfo
|
||||||
|
)
|
||||||
|
|
||||||
# 创建全局ChatBot实例
|
# 创建全局ChatBot实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
209
src/plugins/chat/chat_stream.py
Normal file
209
src/plugins/chat/chat_stream.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Dict, Tuple
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from ...common.database import Database
|
||||||
|
from .message_base import UserInfo, GroupInfo
|
||||||
|
|
||||||
|
|
||||||
|
class ChatStream:
|
||||||
|
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||||
|
def __init__(self,
|
||||||
|
stream_id: str,
|
||||||
|
platform: str,
|
||||||
|
user_info: UserInfo,
|
||||||
|
group_info: Optional[GroupInfo] = None,
|
||||||
|
data: dict = None):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.platform = platform
|
||||||
|
self.user_info = user_info
|
||||||
|
self.group_info = group_info
|
||||||
|
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
|
||||||
|
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
|
||||||
|
self.saved = False
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
result = {
|
||||||
|
'stream_id': self.stream_id,
|
||||||
|
'platform': self.platform,
|
||||||
|
'user_info': self.user_info.to_dict() if self.user_info else None,
|
||||||
|
'group_info': self.group_info.to_dict() if self.group_info else None,
|
||||||
|
'create_time': self.create_time,
|
||||||
|
'last_active_time': self.last_active_time
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> 'ChatStream':
|
||||||
|
"""从字典创建实例"""
|
||||||
|
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
|
||||||
|
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
stream_id=data['stream_id'],
|
||||||
|
platform=data['platform'],
|
||||||
|
user_info=user_info,
|
||||||
|
group_info=group_info,
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_active_time(self):
|
||||||
|
"""更新最后活跃时间"""
|
||||||
|
self.last_active_time = int(time.time())
|
||||||
|
self.saved = False
|
||||||
|
|
||||||
|
|
||||||
|
class ChatManager:
|
||||||
|
"""聊天管理器,管理所有聊天流"""
|
||||||
|
_instance = None
|
||||||
|
_initialized = False
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not self._initialized:
|
||||||
|
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
self._ensure_collection()
|
||||||
|
self._initialized = True
|
||||||
|
# 在事件循环中启动初始化
|
||||||
|
asyncio.create_task(self._initialize())
|
||||||
|
# 启动自动保存任务
|
||||||
|
asyncio.create_task(self._auto_save_task())
|
||||||
|
|
||||||
|
async def _initialize(self):
|
||||||
|
"""异步初始化"""
|
||||||
|
try:
|
||||||
|
await self.load_all_streams()
|
||||||
|
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||||
|
|
||||||
|
async def _auto_save_task(self):
|
||||||
|
"""定期自动保存所有聊天流"""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(300) # 每5分钟保存一次
|
||||||
|
try:
|
||||||
|
await self._save_all_streams()
|
||||||
|
logger.info("聊天流自动保存完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||||
|
|
||||||
|
def _ensure_collection(self):
|
||||||
|
"""确保数据库集合存在并创建索引"""
|
||||||
|
if 'chat_streams' not in self.db.db.list_collection_names():
|
||||||
|
self.db.db.create_collection('chat_streams')
|
||||||
|
# 创建索引
|
||||||
|
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
|
||||||
|
self.db.db.chat_streams.create_index([
|
||||||
|
('platform', 1),
|
||||||
|
('user_info.user_id', 1),
|
||||||
|
('group_info.group_id', 1)
|
||||||
|
])
|
||||||
|
|
||||||
|
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||||
|
"""生成聊天流唯一ID"""
|
||||||
|
# 组合关键信息
|
||||||
|
components = [
|
||||||
|
platform,
|
||||||
|
str(user_info.user_id),
|
||||||
|
str(group_info.group_id) if group_info else 'private'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 使用MD5生成唯一ID
|
||||||
|
key = '_'.join(components)
|
||||||
|
return hashlib.md5(key.encode()).hexdigest()
|
||||||
|
|
||||||
|
async def get_or_create_stream(self,
|
||||||
|
platform: str,
|
||||||
|
user_info: UserInfo,
|
||||||
|
group_info: Optional[GroupInfo] = None) -> ChatStream:
|
||||||
|
"""获取或创建聊天流
|
||||||
|
|
||||||
|
Args:
|
||||||
|
platform: 平台标识
|
||||||
|
user_info: 用户信息
|
||||||
|
group_info: 群组信息(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatStream: 聊天流对象
|
||||||
|
"""
|
||||||
|
# 生成stream_id
|
||||||
|
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||||
|
|
||||||
|
# 检查内存中是否存在
|
||||||
|
if stream_id in self.streams:
|
||||||
|
stream = self.streams[stream_id]
|
||||||
|
# 更新用户信息和群组信息
|
||||||
|
stream.user_info = user_info
|
||||||
|
if group_info:
|
||||||
|
stream.group_info = group_info
|
||||||
|
stream.update_active_time()
|
||||||
|
return stream
|
||||||
|
|
||||||
|
# 检查数据库中是否存在
|
||||||
|
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
|
||||||
|
if data:
|
||||||
|
stream = ChatStream.from_dict(data)
|
||||||
|
# 更新用户信息和群组信息
|
||||||
|
stream.user_info = user_info
|
||||||
|
if group_info:
|
||||||
|
stream.group_info = group_info
|
||||||
|
stream.update_active_time()
|
||||||
|
else:
|
||||||
|
# 创建新的聊天流
|
||||||
|
stream = ChatStream(
|
||||||
|
stream_id=stream_id,
|
||||||
|
platform=platform,
|
||||||
|
user_info=user_info,
|
||||||
|
group_info=group_info
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存到内存和数据库
|
||||||
|
self.streams[stream_id] = stream
|
||||||
|
await self._save_stream(stream)
|
||||||
|
return stream
|
||||||
|
|
||||||
|
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||||
|
"""通过stream_id获取聊天流"""
|
||||||
|
return self.streams.get(stream_id)
|
||||||
|
|
||||||
|
def get_stream_by_info(self,
|
||||||
|
platform: str,
|
||||||
|
user_info: UserInfo,
|
||||||
|
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
|
||||||
|
"""通过信息获取聊天流"""
|
||||||
|
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||||
|
return self.streams.get(stream_id)
|
||||||
|
|
||||||
|
async def _save_stream(self, stream: ChatStream):
|
||||||
|
"""保存聊天流到数据库"""
|
||||||
|
if not stream.saved:
|
||||||
|
self.db.db.chat_streams.update_one(
|
||||||
|
{'stream_id': stream.stream_id},
|
||||||
|
{'$set': stream.to_dict()},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
stream.saved = True
|
||||||
|
|
||||||
|
async def _save_all_streams(self):
|
||||||
|
"""保存所有聊天流"""
|
||||||
|
for stream in self.streams.values():
|
||||||
|
await self._save_stream(stream)
|
||||||
|
|
||||||
|
async def load_all_streams(self):
|
||||||
|
"""从数据库加载所有聊天流"""
|
||||||
|
all_streams = self.db.db.chat_streams.find({})
|
||||||
|
for data in all_streams:
|
||||||
|
stream = ChatStream.from_dict(data)
|
||||||
|
self.streams[stream.stream_id] = stream
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局单例
|
||||||
|
chat_manager = ChatManager()
|
||||||
@@ -373,6 +373,24 @@ class CQCode_tool:
|
|||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_emoji_cq_base64(base64_data: str) -> str:
|
||||||
|
"""
|
||||||
|
创建表情包CQ码
|
||||||
|
Args:
|
||||||
|
base64_data: base64编码的表情包数据
|
||||||
|
Returns:
|
||||||
|
表情包CQ码字符串
|
||||||
|
"""
|
||||||
|
# 转义base64数据
|
||||||
|
escaped_base64 = base64_data.replace('&', '&') \
|
||||||
|
.replace('[', '[') \
|
||||||
|
.replace(']', ']') \
|
||||||
|
.replace(',', ',')
|
||||||
|
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||||
|
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Optional
|
from typing import Optional, Tuple
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class EmojiManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {str(e)}")
|
logger.error(f"记录表情使用失败: {str(e)}")
|
||||||
|
|
||||||
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
|
||||||
"""根据文本内容获取相关表情包
|
"""根据文本内容获取相关表情包
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|||||||
@@ -5,11 +5,10 @@ from typing import Dict, ForwardRef, List, Optional, Union
|
|||||||
import urllib3
|
import urllib3
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .cq_code import CQCode, cq_code_tool
|
|
||||||
from .utils_cq import parse_cq_code
|
|
||||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
from .chat_stream import ChatStream
|
||||||
# 禁用SSL警告
|
# 禁用SSL警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
@@ -40,6 +39,7 @@ class MessageRecv(MessageBase):
|
|||||||
# 处理消息内容
|
# 处理消息内容
|
||||||
self.processed_plain_text = "" # 初始化为空字符串
|
self.processed_plain_text = "" # 初始化为空字符串
|
||||||
self.detailed_plain_text = "" # 初始化为空字符串
|
self.detailed_plain_text = "" # 初始化为空字符串
|
||||||
|
self.is_emoji=False
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
"""处理消息内容,生成纯文本和详细文本
|
"""处理消息内容,生成纯文本和详细文本
|
||||||
@@ -88,6 +88,7 @@ class MessageRecv(MessageBase):
|
|||||||
return await image_manager.get_image_description(seg.data)
|
return await image_manager.get_image_description(seg.data)
|
||||||
return '[图片]'
|
return '[图片]'
|
||||||
elif seg.type == 'emoji':
|
elif seg.type == 'emoji':
|
||||||
|
self.is_emoji=True
|
||||||
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
||||||
return await image_manager.get_emoji_description(seg.data)
|
return await image_manager.get_emoji_description(seg.data)
|
||||||
return '[表情]'
|
return '[表情]'
|
||||||
@@ -115,36 +116,17 @@ class MessageProcessBase(MessageBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user_id: int,
|
chat_stream: ChatStream,
|
||||||
group_id: Optional[int] = None,
|
|
||||||
platform: str = "qq",
|
|
||||||
message_segment: Optional[Seg] = None,
|
message_segment: Optional[Seg] = None,
|
||||||
reply: Optional['MessageRecv'] = None
|
reply: Optional['MessageRecv'] = None
|
||||||
):
|
):
|
||||||
# 构造用户信息
|
|
||||||
user_info = UserInfo(
|
|
||||||
platform=platform,
|
|
||||||
user_id=user_id,
|
|
||||||
user_nickname=get_user_nickname(user_id),
|
|
||||||
user_cardname=get_user_cardname(user_id) if group_id else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构造群组信息(如果有)
|
|
||||||
group_info = None
|
|
||||||
if group_id:
|
|
||||||
group_info = GroupInfo(
|
|
||||||
platform=platform,
|
|
||||||
group_id=group_id,
|
|
||||||
group_name=get_groupname(group_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构造基础消息信息
|
# 构造基础消息信息
|
||||||
message_info = BaseMessageInfo(
|
message_info = BaseMessageInfo(
|
||||||
platform=platform,
|
platform=chat_stream.platform,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
time=int(time.time()),
|
time=int(time.time()),
|
||||||
group_info=group_info,
|
group_info=chat_stream.group_info,
|
||||||
user_info=user_info
|
user_info=chat_stream.user_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
@@ -241,17 +223,13 @@ class MessageThinking(MessageProcessBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user_id: int,
|
chat_stream: ChatStream,
|
||||||
group_id: Optional[int] = None,
|
|
||||||
platform: str = "qq",
|
|
||||||
reply: Optional['MessageRecv'] = None
|
reply: Optional['MessageRecv'] = None
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
user_id=user_id,
|
chat_stream=chat_stream,
|
||||||
group_id=group_id,
|
|
||||||
platform=platform,
|
|
||||||
message_segment=None, # 思考状态不需要消息段
|
message_segment=None, # 思考状态不需要消息段
|
||||||
reply=reply
|
reply=reply
|
||||||
)
|
)
|
||||||
@@ -259,6 +237,15 @@ class MessageThinking(MessageProcessBase):
|
|||||||
# 思考状态特有属性
|
# 思考状态特有属性
|
||||||
self.interrupt = False
|
self.interrupt = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
|
||||||
|
"""从聊天流创建思考状态消息"""
|
||||||
|
return cls(
|
||||||
|
message_id=message_id,
|
||||||
|
chat_stream=chat_stream,
|
||||||
|
reply=reply
|
||||||
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSending(MessageProcessBase):
|
class MessageSending(MessageProcessBase):
|
||||||
"""发送状态的消息类"""
|
"""发送状态的消息类"""
|
||||||
@@ -266,19 +253,16 @@ class MessageSending(MessageProcessBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
user_id: int,
|
chat_stream: ChatStream,
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
group_id: Optional[int] = None,
|
|
||||||
reply: Optional['MessageRecv'] = None,
|
reply: Optional['MessageRecv'] = None,
|
||||||
platform: str = "qq",
|
is_head: bool = False,
|
||||||
is_head: bool = False
|
is_emoji: bool = False
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
user_id=user_id,
|
chat_stream=chat_stream,
|
||||||
group_id=group_id,
|
|
||||||
platform=platform,
|
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
reply=reply
|
reply=reply
|
||||||
)
|
)
|
||||||
@@ -286,6 +270,12 @@ class MessageSending(MessageProcessBase):
|
|||||||
# 发送状态特有属性
|
# 发送状态特有属性
|
||||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||||
self.is_head = is_head
|
self.is_head = is_head
|
||||||
|
self.is_emoji = is_emoji
|
||||||
|
if is_head:
|
||||||
|
self.message_segment = Seg(type='seglist', data=[
|
||||||
|
Seg(type='reply', data=reply.message_info.message_id),
|
||||||
|
self.message_segment
|
||||||
|
])
|
||||||
|
|
||||||
async def process(self) -> None:
|
async def process(self) -> None:
|
||||||
"""处理消息内容,生成纯文本和详细文本"""
|
"""处理消息内容,生成纯文本和详细文本"""
|
||||||
@@ -298,26 +288,24 @@ class MessageSending(MessageProcessBase):
|
|||||||
cls,
|
cls,
|
||||||
thinking: MessageThinking,
|
thinking: MessageThinking,
|
||||||
message_segment: Seg,
|
message_segment: Seg,
|
||||||
reply: Optional['MessageRecv'] = None,
|
is_head: bool = False,
|
||||||
is_head: bool = False
|
is_emoji: bool = False
|
||||||
) -> 'MessageSending':
|
) -> 'MessageSending':
|
||||||
"""从思考状态消息创建发送状态消息"""
|
"""从思考状态消息创建发送状态消息"""
|
||||||
return cls(
|
return cls(
|
||||||
message_id=thinking.message_info.message_id,
|
message_id=thinking.message_info.message_id,
|
||||||
user_id=thinking.message_info.user_info.user_id,
|
chat_stream=thinking.chat_stream,
|
||||||
message_segment=message_segment,
|
message_segment=message_segment,
|
||||||
group_id=thinking.message_info.group_info.group_id if thinking.message_info.group_info else None,
|
reply=thinking.reply,
|
||||||
reply=reply or thinking.reply,
|
is_head=is_head,
|
||||||
platform=thinking.message_info.platform,
|
is_emoji=is_emoji
|
||||||
is_head=is_head
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageSet:
|
class MessageSet:
|
||||||
"""消息集合类,可以存储多个发送消息"""
|
"""消息集合类,可以存储多个发送消息"""
|
||||||
def __init__(self, group_id: int, user_id: int, message_id: str):
|
def __init__(self, chat_stream: ChatStream, message_id: str):
|
||||||
self.group_id = group_id
|
self.chat_stream = chat_stream
|
||||||
self.user_id = user_id
|
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.messages: List[MessageSending] = []
|
self.messages: List[MessageSending] = []
|
||||||
self.time = round(time.time(), 2)
|
self.time = round(time.time(), 2)
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ class MessageSendCQ(MessageCQ):
|
|||||||
elif seg.type == 'image':
|
elif seg.type == 'image':
|
||||||
# 如果是base64图片数据
|
# 如果是base64图片数据
|
||||||
if seg.data.startswith(('data:', 'base64:')):
|
if seg.data.startswith(('data:', 'base64:')):
|
||||||
return f"[CQ:image,file=base64://{seg.data}]"
|
return cq_code_tool.create_emoji_cq_base64(seg.data)
|
||||||
# 如果是表情包(本地文件)
|
# 如果是表情包(本地文件)
|
||||||
return cq_code_tool.create_emoji_cq(seg.data)
|
return cq_code_tool.create_emoji_cq(seg.data)
|
||||||
elif seg.type == 'at':
|
elif seg.type == 'at':
|
||||||
|
|||||||
@@ -5,10 +5,11 @@ from typing import Dict, List, Optional, Union
|
|||||||
from nonebot.adapters.onebot.v11 import Bot
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
|
||||||
from .cq_code import cq_code_tool
|
from .cq_code import cq_code_tool
|
||||||
from .message_cq import Message, Message_Sending, Message_Thinking, MessageSet
|
from .message_cq import MessageSendCQ
|
||||||
|
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .utils import calculate_typing_time
|
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from .chat_stream import chat_manager
|
||||||
|
|
||||||
|
|
||||||
class Message_Sender:
|
class Message_Sender:
|
||||||
@@ -21,66 +22,59 @@ class Message_Sender:
|
|||||||
def set_bot(self, bot: Bot):
|
def set_bot(self, bot: Bot):
|
||||||
"""设置当前bot实例"""
|
"""设置当前bot实例"""
|
||||||
self._current_bot = bot
|
self._current_bot = bot
|
||||||
|
|
||||||
async def send_group_message(
|
|
||||||
self,
|
|
||||||
group_id: int,
|
|
||||||
send_text: str,
|
|
||||||
auto_escape: bool = False,
|
|
||||||
reply_message_id: int = None,
|
|
||||||
at_user_id: int = None
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
if not self._current_bot:
|
async def send_message(
|
||||||
raise RuntimeError("Bot未设置,请先调用set_bot方法设置bot实例")
|
self,
|
||||||
|
message: MessageSending,
|
||||||
message = send_text
|
) -> None:
|
||||||
|
"""发送消息"""
|
||||||
# 如果需要回复
|
if isinstance(message, MessageSending):
|
||||||
if reply_message_id:
|
message_send=MessageSendCQ(
|
||||||
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
|
message_id=message.message_id,
|
||||||
message = reply_cq + message
|
user_id=message.message_info.user_info.user_id,
|
||||||
|
message_segment=message.message_segment,
|
||||||
# 如果需要at
|
reply=message.reply
|
||||||
# if at_user_id:
|
|
||||||
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
|
||||||
# message = at_cq + " " + message
|
|
||||||
|
|
||||||
|
|
||||||
typing_time = calculate_typing_time(message)
|
|
||||||
if typing_time > 10:
|
|
||||||
typing_time = 10
|
|
||||||
await asyncio.sleep(typing_time)
|
|
||||||
|
|
||||||
# 发送消息
|
|
||||||
try:
|
|
||||||
await self._current_bot.send_group_msg(
|
|
||||||
group_id=group_id,
|
|
||||||
message=message,
|
|
||||||
auto_escape=auto_escape
|
|
||||||
)
|
)
|
||||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
if message.message_info.group_info:
|
||||||
except Exception as e:
|
try:
|
||||||
print(f"发生错误 {e}")
|
await self._current_bot.send_group_msg(
|
||||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
group_id=message.message_info.group_info.group_id,
|
||||||
|
message=message_send.raw_message,
|
||||||
|
auto_escape=False
|
||||||
|
)
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误 {e}")
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await self._current_bot.send_private_msg(
|
||||||
|
user_id=message.message_info.user_info.user_id,
|
||||||
|
message=message_send.raw_message,
|
||||||
|
auto_escape=False
|
||||||
|
)
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误 {e}")
|
||||||
|
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||||
|
|
||||||
|
|
||||||
class MessageContainer:
|
class MessageContainer:
|
||||||
"""单个群的发送/思考消息容器"""
|
"""单个聊天流的发送/思考消息容器"""
|
||||||
def __init__(self, group_id: int, max_size: int = 100):
|
def __init__(self, chat_id: str, max_size: int = 100):
|
||||||
self.group_id = group_id
|
self.chat_id = chat_id
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.messages = []
|
self.messages = []
|
||||||
self.last_send_time = 0
|
self.last_send_time = 0
|
||||||
self.thinking_timeout = 20 # 思考超时时间(秒)
|
self.thinking_timeout = 20 # 思考超时时间(秒)
|
||||||
|
|
||||||
def get_timeout_messages(self) -> List[Message_Sending]:
|
def get_timeout_messages(self) -> List[MessageSending]:
|
||||||
"""获取所有超时的Message_Sending对象(思考时间超过30秒),按thinking_start_time排序"""
|
"""获取所有超时的Message_Sending对象(思考时间超过30秒),按thinking_start_time排序"""
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
timeout_messages = []
|
timeout_messages = []
|
||||||
|
|
||||||
for msg in self.messages:
|
for msg in self.messages:
|
||||||
if isinstance(msg, Message_Sending):
|
if isinstance(msg, MessageSending):
|
||||||
if current_time - msg.thinking_start_time > self.thinking_timeout:
|
if current_time - msg.thinking_start_time > self.thinking_timeout:
|
||||||
timeout_messages.append(msg)
|
timeout_messages.append(msg)
|
||||||
|
|
||||||
@@ -89,7 +83,7 @@ class MessageContainer:
|
|||||||
|
|
||||||
return timeout_messages
|
return timeout_messages
|
||||||
|
|
||||||
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
|
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
|
||||||
"""获取thinking_start_time最早的消息对象"""
|
"""获取thinking_start_time最早的消息对象"""
|
||||||
if not self.messages:
|
if not self.messages:
|
||||||
return None
|
return None
|
||||||
@@ -102,16 +96,15 @@ class MessageContainer:
|
|||||||
earliest_message = msg
|
earliest_message = msg
|
||||||
return earliest_message
|
return earliest_message
|
||||||
|
|
||||||
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
|
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
|
||||||
"""添加消息到队列"""
|
"""添加消息到队列"""
|
||||||
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
|
|
||||||
if isinstance(message, MessageSet):
|
if isinstance(message, MessageSet):
|
||||||
for single_message in message.messages:
|
for single_message in message.messages:
|
||||||
self.messages.append(single_message)
|
self.messages.append(single_message)
|
||||||
else:
|
else:
|
||||||
self.messages.append(message)
|
self.messages.append(message)
|
||||||
|
|
||||||
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
|
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
|
||||||
"""移除消息,如果消息存在则返回True,否则返回False"""
|
"""移除消息,如果消息存在则返回True,否则返回False"""
|
||||||
try:
|
try:
|
||||||
if message in self.messages:
|
if message in self.messages:
|
||||||
@@ -126,40 +119,42 @@ class MessageContainer:
|
|||||||
"""检查是否有待发送的消息"""
|
"""检查是否有待发送的消息"""
|
||||||
return bool(self.messages)
|
return bool(self.messages)
|
||||||
|
|
||||||
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
|
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
|
||||||
"""获取所有消息"""
|
"""获取所有消息"""
|
||||||
return list(self.messages)
|
return list(self.messages)
|
||||||
|
|
||||||
|
|
||||||
class MessageManager:
|
class MessageManager:
|
||||||
"""管理所有群的消息容器"""
|
"""管理所有聊天流的消息容器"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.containers: Dict[int, MessageContainer] = {}
|
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
||||||
def get_container(self, group_id: int) -> MessageContainer:
|
def get_container(self, chat_id: str) -> MessageContainer:
|
||||||
"""获取或创建群的消息容器"""
|
"""获取或创建聊天流的消息容器"""
|
||||||
if group_id not in self.containers:
|
if chat_id not in self.containers:
|
||||||
self.containers[group_id] = MessageContainer(group_id)
|
self.containers[chat_id] = MessageContainer(chat_id)
|
||||||
return self.containers[group_id]
|
return self.containers[chat_id]
|
||||||
|
|
||||||
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
|
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
|
||||||
container = self.get_container(message.group_id)
|
chat_stream = chat_manager.get_stream_by_info(
|
||||||
|
platform=message.message_info.platform,
|
||||||
|
user_info=message.message_info.user_info,
|
||||||
|
group_info=message.message_info.group_info
|
||||||
|
)
|
||||||
|
if not chat_stream:
|
||||||
|
raise ValueError("无法找到对应的聊天流")
|
||||||
|
container = self.get_container(chat_stream.stream_id)
|
||||||
container.add_message(message)
|
container.add_message(message)
|
||||||
|
|
||||||
async def process_group_messages(self, group_id: int):
|
async def process_chat_messages(self, chat_id: str):
|
||||||
"""处理群消息"""
|
"""处理聊天流消息"""
|
||||||
# if int(time.time() / 3) == time.time() / 3:
|
container = self.get_container(chat_id)
|
||||||
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
|
||||||
container = self.get_container(group_id)
|
|
||||||
if container.has_messages():
|
if container.has_messages():
|
||||||
#最早的对象,可能是思考消息,也可能是发送消息
|
message_earliest = container.get_earliest_message()
|
||||||
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
|
|
||||||
|
|
||||||
#如果是思考消息
|
if isinstance(message_earliest, MessageThinking):
|
||||||
if isinstance(message_earliest, Message_Thinking):
|
|
||||||
#优先等待这条消息
|
|
||||||
message_earliest.update_thinking_time()
|
message_earliest.update_thinking_time()
|
||||||
thinking_time = message_earliest.thinking_time
|
thinking_time = message_earliest.thinking_time
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒\033[K\r", end='', flush=True)
|
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}秒\033[K\r", end='', flush=True)
|
||||||
@@ -168,42 +163,36 @@ class MessageManager:
|
|||||||
if thinking_time > global_config.thinking_timeout:
|
if thinking_time > global_config.thinking_timeout:
|
||||||
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
|
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
|
||||||
container.remove_message(message_earliest)
|
container.remove_message(message_earliest)
|
||||||
else:# 如果不是message_thinking就只能是message_sending
|
else:
|
||||||
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||||
#直接发,等什么呢
|
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
||||||
if message_earliest.is_head and message_earliest.update_thinking_time() >30:
|
await message_sender.send_message(message_earliest)
|
||||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
|
|
||||||
else:
|
else:
|
||||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False)
|
await message_sender.send_message(message_earliest)
|
||||||
#移除消息
|
|
||||||
if message_earliest.is_emoji:
|
if message_earliest.is_emoji:
|
||||||
message_earliest.processed_plain_text = "[表情包]"
|
message_earliest.processed_plain_text = "[表情包]"
|
||||||
await self.storage.store_message(message_earliest, None)
|
await self.storage.store_message(message_earliest, None)
|
||||||
|
|
||||||
container.remove_message(message_earliest)
|
container.remove_message(message_earliest)
|
||||||
|
|
||||||
#获取并处理超时消息
|
message_timeout = container.get_timeout_messages()
|
||||||
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
|
|
||||||
if message_timeout:
|
if message_timeout:
|
||||||
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
|
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
|
||||||
for msg in message_timeout:
|
for msg in message_timeout:
|
||||||
if msg == message_earliest:
|
if msg == message_earliest:
|
||||||
continue # 跳过已经处理过的消息
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
#发送
|
if msg.is_head and msg.update_thinking_time() > 30:
|
||||||
if msg.is_head and msg.update_thinking_time() >30:
|
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
||||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
|
||||||
else:
|
else:
|
||||||
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False)
|
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False)
|
||||||
|
|
||||||
|
|
||||||
#如果是表情包,则替换为"[表情包]"
|
|
||||||
if msg.is_emoji:
|
if msg.is_emoji:
|
||||||
msg.processed_plain_text = "[表情包]"
|
msg.processed_plain_text = "[表情包]"
|
||||||
await self.storage.store_message(msg, None)
|
await self.storage.store_message(msg, None)
|
||||||
|
|
||||||
# 安全地移除消息
|
|
||||||
if not container.remove_message(msg):
|
if not container.remove_message(msg):
|
||||||
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -215,8 +204,8 @@ class MessageManager:
|
|||||||
while self._running:
|
while self._running:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
tasks = []
|
tasks = []
|
||||||
for group_id in self.containers.keys():
|
for chat_id in self.containers.keys():
|
||||||
tasks.append(self.process_group_messages(group_id))
|
tasks.append(self.process_chat_messages(chat_id))
|
||||||
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
|
from .message_base import UserInfo
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
|
||||||
class Impression:
|
class Impression:
|
||||||
traits: str = None
|
traits: str = None
|
||||||
@@ -13,60 +14,77 @@ class Impression:
|
|||||||
|
|
||||||
class Relationship:
|
class Relationship:
|
||||||
user_id: int = None
|
user_id: int = None
|
||||||
# impression: Impression = None
|
platform: str = None
|
||||||
# group_id: int = None
|
|
||||||
# group_name: str = None
|
|
||||||
gender: str = None
|
gender: str = None
|
||||||
age: int = None
|
age: int = None
|
||||||
nickname: str = None
|
nickname: str = None
|
||||||
relationship_value: float = None
|
relationship_value: float = None
|
||||||
saved = False
|
saved = False
|
||||||
|
|
||||||
def __init__(self, user_id: int, data=None, **kwargs):
|
def __init__(self, chat:ChatStream,data:dict):
|
||||||
if isinstance(data, dict):
|
self.user_id=chat.user_info.user_id
|
||||||
# 如果输入是字典,使用字典解析
|
self.platform=chat.platform
|
||||||
self.user_id = data.get('user_id')
|
self.nickname=chat.user_info.user_nickname
|
||||||
self.gender = data.get('gender')
|
self.relationship_value=data.get('relationship_value',0)
|
||||||
self.age = data.get('age')
|
self.age=data.get('age',0)
|
||||||
self.nickname = data.get('nickname')
|
self.gender=data.get('gender','')
|
||||||
self.relationship_value = data.get('relationship_value', 0.0)
|
|
||||||
self.saved = data.get('saved', False)
|
|
||||||
else:
|
|
||||||
# 如果是直接传入属性值
|
|
||||||
self.user_id = kwargs.get('user_id')
|
|
||||||
self.gender = kwargs.get('gender')
|
|
||||||
self.age = kwargs.get('age')
|
|
||||||
self.nickname = kwargs.get('nickname')
|
|
||||||
self.relationship_value = kwargs.get('relationship_value', 0.0)
|
|
||||||
self.saved = kwargs.get('saved', False)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RelationshipManager:
|
class RelationshipManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.relationships: dict[int, Relationship] = {}
|
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||||
|
|
||||||
async def update_relationship(self, user_id: int, data=None, **kwargs):
|
async def update_relationship(self,
|
||||||
|
chat_stream:ChatStream,
|
||||||
|
data: dict = None,
|
||||||
|
**kwargs) -> Optional[Relationship]:
|
||||||
|
"""更新或创建关系
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
data: 字典格式的数据(可选)
|
||||||
|
**kwargs: 其他参数
|
||||||
|
Returns:
|
||||||
|
Relationship: 关系对象
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if chat_stream.user_info is not None:
|
||||||
|
user_id = chat_stream.user_info.user_id
|
||||||
|
platform = chat_stream.user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
|
# 使用(user_id, platform)作为键
|
||||||
|
key = (user_id, platform)
|
||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(user_id)
|
relationship = self.relationships.get(key)
|
||||||
if relationship:
|
if relationship:
|
||||||
# 如果存在,更新现有对象
|
# 如果存在,更新现有对象
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
for key, value in data.items():
|
for k, value in data.items():
|
||||||
if hasattr(relationship, key) and value is not None:
|
if hasattr(relationship, k) and value is not None:
|
||||||
setattr(relationship, key, value)
|
setattr(relationship, k, value)
|
||||||
else:
|
else:
|
||||||
for key, value in kwargs.items():
|
for k, value in kwargs.items():
|
||||||
if hasattr(relationship, key) and value is not None:
|
if hasattr(relationship, k) and value is not None:
|
||||||
setattr(relationship, key, value)
|
setattr(relationship, k, value)
|
||||||
else:
|
else:
|
||||||
# 如果不存在,创建新对象
|
# 如果不存在,创建新对象
|
||||||
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs)
|
if user_info is not None:
|
||||||
self.relationships[user_id] = relationship
|
relationship = Relationship(user_info=user_info, **kwargs)
|
||||||
|
elif isinstance(data, dict):
|
||||||
# 更新 id_name_nickname_table
|
data['platform'] = platform
|
||||||
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
|
relationship = Relationship(user_id=user_id, data=data)
|
||||||
|
else:
|
||||||
|
kwargs['platform'] = platform
|
||||||
|
kwargs['user_id'] = user_id
|
||||||
|
relationship = Relationship(**kwargs)
|
||||||
|
self.relationships[key] = relationship
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
@@ -74,33 +92,87 @@ class RelationshipManager:
|
|||||||
|
|
||||||
return relationship
|
return relationship
|
||||||
|
|
||||||
async def update_relationship_value(self, user_id: int, **kwargs):
|
async def update_relationship_value(self,
|
||||||
|
user_id: int = None,
|
||||||
|
platform: str = None,
|
||||||
|
user_info: UserInfo = None,
|
||||||
|
**kwargs) -> Optional[Relationship]:
|
||||||
|
"""更新关系值
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
**kwargs: 其他参数
|
||||||
|
Returns:
|
||||||
|
Relationship: 关系对象
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
|
# 使用(user_id, platform)作为键
|
||||||
|
key = (user_id, platform)
|
||||||
|
|
||||||
# 检查是否在内存中已存在
|
# 检查是否在内存中已存在
|
||||||
relationship = self.relationships.get(user_id)
|
relationship = self.relationships.get(key)
|
||||||
if relationship:
|
if relationship:
|
||||||
for key, value in kwargs.items():
|
for k, value in kwargs.items():
|
||||||
if key == 'relationship_value':
|
if k == 'relationship_value':
|
||||||
relationship.relationship_value += value
|
relationship.relationship_value += value
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
return relationship
|
return relationship
|
||||||
else:
|
else:
|
||||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新")
|
# 如果不存在且提供了user_info,则创建新的关系
|
||||||
|
if user_info is not None:
|
||||||
|
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||||
|
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_relationship(self,
|
||||||
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
user_id: int = None,
|
||||||
"""获取用户关系对象"""
|
platform: str = None,
|
||||||
if user_id in self.relationships:
|
user_info: UserInfo = None) -> Optional[Relationship]:
|
||||||
return self.relationships[user_id]
|
"""获取用户关系对象
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
Returns:
|
||||||
|
Relationship: 关系对象
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
|
key = (user_id, platform)
|
||||||
|
if key in self.relationships:
|
||||||
|
return self.relationships[key]
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
async def load_relationship(self, data: dict) -> Relationship:
|
async def load_relationship(self, data: dict) -> Relationship:
|
||||||
"""从数据库加载或创建新的关系对象"""
|
"""从数据库加载或创建新的关系对象"""
|
||||||
rela = Relationship(user_id=data['user_id'], data=data)
|
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||||
|
if 'platform' not in data:
|
||||||
|
data['platform'] = 'qq'
|
||||||
|
|
||||||
|
rela = Relationship(data=data)
|
||||||
rela.saved = True
|
rela.saved = True
|
||||||
self.relationships[rela.user_id] = rela
|
key = (rela.user_id, rela.platform)
|
||||||
|
self.relationships[key] = rela
|
||||||
return rela
|
return rela
|
||||||
|
|
||||||
async def load_all_relationships(self):
|
async def load_all_relationships(self):
|
||||||
@@ -117,9 +189,7 @@ class RelationshipManager:
|
|||||||
all_relationships = db.db.relationships.find({})
|
all_relationships = db.db.relationships.find({})
|
||||||
# 依次加载每条记录
|
# 依次加载每条记录
|
||||||
for data in all_relationships:
|
for data in all_relationships:
|
||||||
user_id = data['user_id']
|
await self.load_relationship(data)
|
||||||
relationship = await self.load_relationship(data)
|
|
||||||
self.relationships[user_id] = relationship
|
|
||||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
@@ -130,16 +200,15 @@ class RelationshipManager:
|
|||||||
async def _save_all_relationships(self):
|
async def _save_all_relationships(self):
|
||||||
"""将所有关系数据保存到数据库"""
|
"""将所有关系数据保存到数据库"""
|
||||||
# 保存所有关系数据
|
# 保存所有关系数据
|
||||||
for userid, relationship in self.relationships.items():
|
for (userid, platform), relationship in self.relationships.items():
|
||||||
if not relationship.saved:
|
if not relationship.saved:
|
||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
|
|
||||||
async def storage_relationship(self,relationship: Relationship):
|
async def storage_relationship(self, relationship: Relationship):
|
||||||
"""
|
"""将关系记录存储到数据库中"""
|
||||||
将关系记录存储到数据库中
|
|
||||||
"""
|
|
||||||
user_id = relationship.user_id
|
user_id = relationship.user_id
|
||||||
|
platform = relationship.platform
|
||||||
nickname = relationship.nickname
|
nickname = relationship.nickname
|
||||||
relationship_value = relationship.relationship_value
|
relationship_value = relationship.relationship_value
|
||||||
gender = relationship.gender
|
gender = relationship.gender
|
||||||
@@ -148,8 +217,9 @@ class RelationshipManager:
|
|||||||
|
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
db.db.relationships.update_one(
|
db.db.relationships.update_one(
|
||||||
{'user_id': user_id},
|
{'user_id': user_id, 'platform': platform},
|
||||||
{'$set': {
|
{'$set': {
|
||||||
|
'platform': platform,
|
||||||
'nickname': nickname,
|
'nickname': nickname,
|
||||||
'relationship_value': relationship_value,
|
'relationship_value': relationship_value,
|
||||||
'gender': gender,
|
'gender': gender,
|
||||||
@@ -159,12 +229,35 @@ class RelationshipManager:
|
|||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_name(self, user_id: int) -> str:
|
def get_name(self,
|
||||||
|
user_id: int = None,
|
||||||
|
platform: str = None,
|
||||||
|
user_info: UserInfo = None) -> str:
|
||||||
|
"""获取用户昵称
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||||
|
platform: 平台(可选,如果提供user_info则不需要)
|
||||||
|
user_info: 用户信息对象(可选)
|
||||||
|
Returns:
|
||||||
|
str: 用户昵称
|
||||||
|
"""
|
||||||
|
# 确定user_id和platform
|
||||||
|
if user_info is not None:
|
||||||
|
user_id = user_info.user_id
|
||||||
|
platform = user_info.platform or 'qq'
|
||||||
|
else:
|
||||||
|
platform = platform or 'qq'
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise ValueError("必须提供user_id或user_info")
|
||||||
|
|
||||||
# 确保user_id是整数类型
|
# 确保user_id是整数类型
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
if user_id in self.relationships:
|
key = (user_id, platform)
|
||||||
|
if key in self.relationships:
|
||||||
return self.relationships[user_id].nickname
|
return self.relationships[key].nickname
|
||||||
|
elif user_info is not None:
|
||||||
|
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||||
else:
|
else:
|
||||||
return "某人"
|
return "某人"
|
||||||
|
|
||||||
|
|||||||
@@ -1,47 +1,26 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message_cq import Message
|
from .message_base import MessageBase
|
||||||
|
from .message import MessageSending, MessageRecv
|
||||||
|
from .chat_stream import ChatStream
|
||||||
|
|
||||||
class MessageStorage:
|
class MessageStorage:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
|
|
||||||
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
|
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
try:
|
try:
|
||||||
if not message.is_emoji:
|
message_data = {
|
||||||
message_data = {
|
"message_id": message.message_info.message_id,
|
||||||
"group_id": message.group_id,
|
"time": message.message_info.time,
|
||||||
"user_id": message.user_id,
|
"chat_id":chat_stream.stream_id,
|
||||||
"message_id": message.message_id,
|
"chat_info": chat_stream.to_dict(),
|
||||||
"raw_message": message.raw_message,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
"plain_text": message.plain_text,
|
|
||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
"time": message.time,
|
|
||||||
"user_nickname": message.user_nickname,
|
|
||||||
"user_cardname": message.user_cardname,
|
|
||||||
"group_name": message.group_name,
|
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
message_data = {
|
|
||||||
"group_id": message.group_id,
|
|
||||||
"user_id": message.user_id,
|
|
||||||
"message_id": message.message_id,
|
|
||||||
"raw_message": message.raw_message,
|
|
||||||
"plain_text": message.plain_text,
|
|
||||||
"processed_plain_text": '[表情包]',
|
|
||||||
"time": message.time,
|
|
||||||
"user_nickname": message.user_nickname,
|
|
||||||
"user_cardname": message.user_cardname,
|
|
||||||
"group_name": message.group_name,
|
|
||||||
"topic": topic,
|
|
||||||
"detailed_plain_text": message.detailed_plain_text,
|
|
||||||
}
|
|
||||||
|
|
||||||
self.db.db.messages.insert_one(message_data)
|
self.db.db.messages.insert_one(message_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Dict
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
|
from .message_base import UserInfo, GroupInfo
|
||||||
|
from .chat_stream import chat_manager,ChatStream
|
||||||
|
|
||||||
|
|
||||||
class WillingManager:
|
class WillingManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.group_reply_willing = {} # 存储每个群的回复意愿
|
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||||
self._decay_task = None
|
self._decay_task = None
|
||||||
self._started = False
|
self._started = False
|
||||||
|
|
||||||
@@ -12,20 +17,33 @@ class WillingManager:
|
|||||||
"""定期衰减回复意愿"""
|
"""定期衰减回复意愿"""
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
for group_id in self.group_reply_willing:
|
for chat_id in self.chat_reply_willing:
|
||||||
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||||
|
|
||||||
def get_willing(self, group_id: int) -> float:
|
def get_willing(self,chat_stream:ChatStream) -> float:
|
||||||
"""获取指定群组的回复意愿"""
|
"""获取指定聊天流的回复意愿"""
|
||||||
return self.group_reply_willing.get(group_id, 0)
|
stream = chat_stream
|
||||||
|
if stream:
|
||||||
|
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
def set_willing(self, group_id: int, willing: float):
|
def set_willing(self, chat_id: str, willing: float):
|
||||||
"""设置指定群组的回复意愿"""
|
"""设置指定聊天流的回复意愿"""
|
||||||
self.group_reply_willing[group_id] = willing
|
self.chat_reply_willing[chat_id] = willing
|
||||||
|
|
||||||
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
|
async def change_reply_willing_received(self,
|
||||||
"""改变指定群组的回复意愿并返回回复概率"""
|
chat_stream:ChatStream,
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
topic: str = None,
|
||||||
|
is_mentioned_bot: bool = False,
|
||||||
|
config = None,
|
||||||
|
is_emoji: bool = False,
|
||||||
|
interested_rate: float = 0) -> float:
|
||||||
|
"""改变指定聊天流的回复意愿并返回回复概率"""
|
||||||
|
# 获取或创建聊天流
|
||||||
|
stream = chat_stream
|
||||||
|
chat_id = stream.stream_id
|
||||||
|
|
||||||
|
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||||
|
|
||||||
# print(f"初始意愿: {current_willing}")
|
# print(f"初始意愿: {current_willing}")
|
||||||
if is_mentioned_bot and current_willing < 1.0:
|
if is_mentioned_bot and current_willing < 1.0:
|
||||||
@@ -49,31 +67,37 @@ class WillingManager:
|
|||||||
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
||||||
|
|
||||||
reply_probability = max((current_willing - 0.45) * 2, 0)
|
reply_probability = max((current_willing - 0.45) * 2, 0)
|
||||||
if group_id not in config.talk_allowed_groups:
|
|
||||||
current_willing = 0
|
# 检查群组权限(如果是群聊)
|
||||||
reply_probability = 0
|
if chat_stream.group_info:
|
||||||
|
if chat_stream.group_info.group_id not in config.talk_allowed_groups:
|
||||||
if group_id in config.talk_frequency_down_groups:
|
current_willing = 0
|
||||||
reply_probability = reply_probability / global_config.down_frequency_rate
|
reply_probability = 0
|
||||||
|
|
||||||
|
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
|
||||||
|
reply_probability = reply_probability / global_config.down_frequency_rate
|
||||||
|
|
||||||
reply_probability = min(reply_probability, 1)
|
reply_probability = min(reply_probability, 1)
|
||||||
if reply_probability < 0:
|
if reply_probability < 0:
|
||||||
reply_probability = 0
|
reply_probability = 0
|
||||||
|
|
||||||
|
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
|
||||||
return reply_probability
|
return reply_probability
|
||||||
|
|
||||||
def change_reply_willing_sent(self, group_id: int):
|
def change_reply_willing_sent(self, chat_stream:ChatStream):
|
||||||
"""开始思考后降低群组的回复意愿"""
|
"""开始思考后降低聊天流的回复意愿"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
stream = chat_stream
|
||||||
self.group_reply_willing[group_id] = max(0, current_willing - 2)
|
if stream:
|
||||||
|
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||||
|
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
|
||||||
|
|
||||||
def change_reply_willing_after_sent(self, group_id: int):
|
def change_reply_willing_after_sent(self,chat_stream:ChatStream):
|
||||||
"""发送消息后提高群组的回复意愿"""
|
"""发送消息后提高聊天流的回复意愿"""
|
||||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
stream = chat_stream
|
||||||
if current_willing < 1:
|
if stream:
|
||||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
|
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||||
|
if current_willing < 1:
|
||||||
|
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
|
||||||
|
|
||||||
async def ensure_started(self):
|
async def ensure_started(self):
|
||||||
"""确保衰减任务已启动"""
|
"""确保衰减任务已启动"""
|
||||||
|
|||||||
Reference in New Issue
Block a user