refractor: 几乎写完了,进入测试阶段

This commit is contained in:
tcmofashi
2025-03-09 22:12:10 +08:00
parent fe3684736a
commit 6e2ea8261b
10 changed files with 654 additions and 338 deletions

View File

@@ -10,18 +10,19 @@ from .config import global_config
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块 from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import ( from .message_cq import (
Message, MessageRecvCQ,
Message_Sending, MessageSendCQ,
Message_Thinking, # 导入 Message_Thinking 类
MessageSet,
) )
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot: class ChatBot:
def __init__(self): def __init__(self):
@@ -43,12 +44,9 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None: async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息""" """处理收到的群消息"""
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例 self.bot = bot # 更新 bot 实例
if event.user_id in global_config.ban_user_id:
return
group_info = await bot.get_group_info(group_id=event.group_id) group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
@@ -56,25 +54,46 @@ class ChatBot:
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
message = Message( message_cq=MessageRecvCQ(
group_id=event.group_id,
user_id=event.user_id,
message_id=event.message_id, message_id=event.message_id,
user_cardname=sender_info['card'], user_id=event.user_id,
raw_message=str(event.original_message), raw_message=str(event.original_message),
plain_text=event.get_plaintext(), group_id=event.group_id,
reply_message=event.reply, reply_message=event.reply,
platform='qq'
) )
await message.initialize() message_json=message_cq.to_dict()
# 进入maimbot
message=MessageRecv(**message_json)
await message.process()
groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info
messageinfo=message.message_info
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
# 消息过滤涉及到config有待更新
if groupinfo:
if groupinfo.group_id not in global_config.talk_allowed_groups:
return
else:
if userinfo:
if userinfo.user_id in []:
pass
else:
return
else:
return
if userinfo.user_id in global_config.ban_user_id:
return
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.detailed_plain_text: if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}") logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered") logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
@@ -85,47 +104,55 @@ class ChatBot:
print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n") print(f"\033[1;32m[记忆激活]\033[0m 对{message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None) await self.storage.store_message(message,chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
reply_probability = willing_manager.change_reply_willing_received( reply_probability = await willing_manager.change_reply_willing_received(
event.group_id, chat_stream=chat,
topic[0] if topic else None, topic=topic[0] if topic else None,
is_mentioned, is_mentioned_bot=is_mentioned,
global_config, config=global_config,
event.user_id, is_emoji=message.is_emoji,
message.is_emoji, interested_rate=interested_rate
interested_rate )
current_willing = willing_manager.get_willing(
chat_stream=chat
) )
current_willing = willing_manager.get_willing(event.group_id)
print(f"\033[1;32m[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m") response = None
response = ""
if random() < reply_probability: if random() < reply_probability:
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform
)
tinking_time_point = round(time.time(), 2) tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point) think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message,message_id=think_id) thinking_message = MessageThinking.from_chat_stream(
chat_stream=chat,
message_id=think_id,
reply=message
)
message_manager.add_message(thinking_message) message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id) willing_manager.change_reply_willing_sent(
chat_stream=chat
)
response,raw_content = await self.gpt.generate_response(message) response,raw_content = await self.gpt.generate_response(message)
if response: if response:
container = message_manager.get_container(event.group_id) container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
# 找到message,删除 # 找到message,删除
for msg in container.messages: for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id: if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
thinking_message = msg thinking_message = msg
container.messages.remove(msg) container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break break
# 如果找不到思考消息,直接返回 # 如果找不到思考消息,直接返回
@@ -135,11 +162,10 @@ class ChatBot:
#记录开始思考的时间,避免从思考到回复的时间太久 #记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的 message_set = MessageSet(chat, think_id)
#计算打字时间1是为了模拟打字2是避免多条回复乱序 #计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0 accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False mark_head = False
for msg in response: for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}") # print(f"\033[1;32m[回复内容]\033[0m {msg}")
@@ -148,22 +174,16 @@ class ChatBot:
accu_typing_time += typing_time accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time timepoint = tinking_time_point + accu_typing_time
bot_message = Message_Sending( message_segment = Seg(type='text', data=msg)
group_id=event.group_id, bot_message = MessageSending(
user_id=global_config.BOT_QQ,
message_id=think_id, message_id=think_id,
raw_message=msg, chat_stream=chat,
plain_text=msg, message_segment=message_segment,
processed_plain_text=msg, reply=message,
user_nickname=global_config.BOT_NICKNAME, is_head=not mark_head,
group_name=message.group_name, is_emoji=False
time=timepoint, #记录了回复生成的时间
thinking_start_time=thinking_start_time, #记录了思考开始的时间
reply_message_id=message.message_id
) )
await bot_message.initialize()
if not mark_head: if not mark_head:
bot_message.is_head = True
mark_head = True mark_head = True
message_set.add_message(bot_message) message_set.add_message(bot_message)
@@ -180,30 +200,22 @@ class ChatBot:
if emoji_raw != None: if emoji_raw != None:
emoji_path,discription = emoji_raw emoji_path,discription = emoji_raw
emoji_cq = cq_code_tool.create_emoji_cq(emoji_path) emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5: if random() < 0.5:
bot_response_time = tinking_time_point - 1 bot_response_time = tinking_time_point - 1
else: else:
bot_response_time = bot_response_time + 1 bot_response_time = bot_response_time + 1
bot_message = Message_Sending( message_segment = Seg(type='emoji', data=emoji_cq)
group_id=event.group_id, bot_message = MessageSending(
user_id=global_config.BOT_QQ, message_id=think_id,
message_id=0, chat_stream=chat,
raw_message=emoji_cq, message_segment=message_segment,
plain_text=emoji_cq, reply=message,
processed_plain_text=emoji_cq, is_head=False,
detailed_plain_text=discription, is_emoji=True
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
) )
await bot_message.initialize()
message_manager.add_message(bot_message) message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content) emotion = await self.gpt._get_emotion_tags(raw_content)
print(f"'{response}' 获取到的情感标签为:{emotion}") print(f"'{response}' 获取到的情感标签为:{emotion}")
@@ -220,7 +232,11 @@ class ChatBot:
# 使用情绪管理器更新情绪 # 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id) willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo
)
# 创建全局ChatBot实例 # 创建全局ChatBot实例
chat_bot = ChatBot() chat_bot = ChatBot()

View File

@@ -0,0 +1,209 @@
import time
import asyncio
from typing import Optional, Dict, Tuple
import hashlib
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo, GroupInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
'stream_id': self.stream_id,
'platform': self.platform,
'user_info': self.user_info.to_dict() if self.user_info else None,
'group_info': self.group_info.to_dict() if self.group_info else None,
'create_time': self.create_time,
'last_active_time': self.last_active_time
}
return result
@classmethod
def from_dict(cls, data: dict) -> 'ChatStream':
"""从字典创建实例"""
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
return cls(
stream_id=data['stream_id'],
platform=data['platform'],
user_info=user_info,
group_info=group_info,
data=data
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.saved = False
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
asyncio.create_task(self._initialize())
# 启动自动保存任务
asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if 'chat_streams' not in self.db.db.list_collection_names():
self.db.db.create_collection('chat_streams')
# 创建索引
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
self.db.db.chat_streams.create_index([
('platform', 1),
('user_info.user_id', 1),
('group_info.group_id', 1)
])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
# 组合关键信息
components = [
platform,
str(user_info.user_id),
str(group_info.group_id) if group_info else 'private'
]
# 使用MD5生成唯一ID
key = '_'.join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{'stream_id': stream.stream_id},
{'$set': stream.to_dict()},
upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
# 创建全局单例
chat_manager = ChatManager()

View File

@@ -373,6 +373,24 @@ class CQCode_tool:
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = base64_data.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"

View File

@@ -3,7 +3,7 @@ import os
import random import random
import time import time
import traceback import traceback
from typing import Optional from typing import Optional, Tuple
import base64 import base64
import hashlib import hashlib
@@ -92,7 +92,7 @@ class EmojiManager:
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[str]: async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
"""根据文本内容获取相关表情包 """根据文本内容获取相关表情包
Args: Args:
text: 输入文本 text: 输入文本

View File

@@ -5,11 +5,10 @@ from typing import Dict, ForwardRef, List, Optional, Union
import urllib3 import urllib3
from loguru import logger from loguru import logger
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_image import image_manager from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -40,6 +39,7 @@ class MessageRecv(MessageBase):
# 处理消息内容 # 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串 self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本 """处理消息内容,生成纯文本和详细文本
@@ -88,6 +88,7 @@ class MessageRecv(MessageBase):
return await image_manager.get_image_description(seg.data) return await image_manager.get_image_description(seg.data)
return '[图片]' return '[图片]'
elif seg.type == 'emoji': elif seg.type == 'emoji':
self.is_emoji=True
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_emoji_description(seg.data) return await image_manager.get_emoji_description(seg.data)
return '[表情]' return '[表情]'
@@ -115,36 +116,17 @@ class MessageProcessBase(MessageBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
user_id: int, chat_stream: ChatStream,
group_id: Optional[int] = None,
platform: str = "qq",
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None reply: Optional['MessageRecv'] = None
): ):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息 # 构造基础消息信息
message_info = BaseMessageInfo( message_info = BaseMessageInfo(
platform=platform, platform=chat_stream.platform,
message_id=message_id, message_id=message_id,
time=int(time.time()), time=int(time.time()),
group_info=group_info, group_info=chat_stream.group_info,
user_info=user_info user_info=chat_stream.user_info
) )
# 调用父类初始化 # 调用父类初始化
@@ -241,17 +223,13 @@ class MessageThinking(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
user_id: int, chat_stream: ChatStream,
group_id: Optional[int] = None,
platform: str = "qq",
reply: Optional['MessageRecv'] = None reply: Optional['MessageRecv'] = None
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
message_id=message_id, message_id=message_id,
user_id=user_id, chat_stream=chat_stream,
group_id=group_id,
platform=platform,
message_segment=None, # 思考状态不需要消息段 message_segment=None, # 思考状态不需要消息段
reply=reply reply=reply
) )
@@ -259,6 +237,15 @@ class MessageThinking(MessageProcessBase):
# 思考状态特有属性 # 思考状态特有属性
self.interrupt = False self.interrupt = False
@classmethod
def from_chat_stream(cls, chat_stream: ChatStream, message_id: str, reply: Optional['MessageRecv'] = None) -> 'MessageThinking':
"""从聊天流创建思考状态消息"""
return cls(
message_id=message_id,
chat_stream=chat_stream,
reply=reply
)
@dataclass @dataclass
class MessageSending(MessageProcessBase): class MessageSending(MessageProcessBase):
"""发送状态的消息类""" """发送状态的消息类"""
@@ -266,19 +253,16 @@ class MessageSending(MessageProcessBase):
def __init__( def __init__(
self, self,
message_id: str, message_id: str,
user_id: int, chat_stream: ChatStream,
message_segment: Seg, message_segment: Seg,
group_id: Optional[int] = None,
reply: Optional['MessageRecv'] = None, reply: Optional['MessageRecv'] = None,
platform: str = "qq", is_head: bool = False,
is_head: bool = False is_emoji: bool = False
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(
message_id=message_id, message_id=message_id,
user_id=user_id, chat_stream=chat_stream,
group_id=group_id,
platform=platform,
message_segment=message_segment, message_segment=message_segment,
reply=reply reply=reply
) )
@@ -286,6 +270,12 @@ class MessageSending(MessageProcessBase):
# 发送状态特有属性 # 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head self.is_head = is_head
self.is_emoji = is_emoji
if is_head:
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本""" """处理消息内容,生成纯文本和详细文本"""
@@ -298,26 +288,24 @@ class MessageSending(MessageProcessBase):
cls, cls,
thinking: MessageThinking, thinking: MessageThinking,
message_segment: Seg, message_segment: Seg,
reply: Optional['MessageRecv'] = None, is_head: bool = False,
is_head: bool = False is_emoji: bool = False
) -> 'MessageSending': ) -> 'MessageSending':
"""从思考状态消息创建发送状态消息""" """从思考状态消息创建发送状态消息"""
return cls( return cls(
message_id=thinking.message_info.message_id, message_id=thinking.message_info.message_id,
user_id=thinking.message_info.user_info.user_id, chat_stream=thinking.chat_stream,
message_segment=message_segment, message_segment=message_segment,
group_id=thinking.message_info.group_info.group_id if thinking.message_info.group_info else None, reply=thinking.reply,
reply=reply or thinking.reply, is_head=is_head,
platform=thinking.message_info.platform, is_emoji=is_emoji
is_head=is_head
) )
@dataclass @dataclass
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str): def __init__(self, chat_stream: ChatStream, message_id: str):
self.group_id = group_id self.chat_stream = chat_stream
self.user_id = user_id
self.message_id = message_id self.message_id = message_id
self.messages: List[MessageSending] = [] self.messages: List[MessageSending] = []
self.time = round(time.time(), 2) self.time = round(time.time(), 2)

View File

@@ -176,7 +176,7 @@ class MessageSendCQ(MessageCQ):
elif seg.type == 'image': elif seg.type == 'image':
# 如果是base64图片数据 # 如果是base64图片数据
if seg.data.startswith(('data:', 'base64:')): if seg.data.startswith(('data:', 'base64:')):
return f"[CQ:image,file=base64://{seg.data}]" return cq_code_tool.create_emoji_cq_base64(seg.data)
# 如果是表情包(本地文件) # 如果是表情包(本地文件)
return cq_code_tool.create_emoji_cq(seg.data) return cq_code_tool.create_emoji_cq(seg.data)
elif seg.type == 'at': elif seg.type == 'at':

View File

@@ -5,10 +5,11 @@ from typing import Dict, List, Optional, Union
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool from .cq_code import cq_code_tool
from .message_cq import Message, Message_Sending, Message_Thinking, MessageSet from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config from .config import global_config
from .chat_stream import chat_manager
class Message_Sender: class Message_Sender:
@@ -22,65 +23,58 @@ class Message_Sender:
"""设置当前bot实例""" """设置当前bot实例"""
self._current_bot = bot self._current_bot = bot
async def send_group_message( async def send_message(
self, self,
group_id: int, message: MessageSending,
send_text: str,
auto_escape: bool = False,
reply_message_id: int = None,
at_user_id: int = None
) -> None: ) -> None:
"""发送消息"""
if not self._current_bot: if isinstance(message, MessageSending):
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例") message_send=MessageSendCQ(
message_id=message.message_id,
message = send_text user_id=message.message_info.user_info.user_id,
message_segment=message.message_segment,
# 如果需要回复 reply=message.reply
if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message
# 如果需要at
# if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message
typing_time = calculate_typing_time(message)
if typing_time > 10:
typing_time = 10
await asyncio.sleep(typing_time)
# 发送消息
try:
await self._current_bot.send_group_msg(
group_id=group_id,
message=message,
auto_escape=auto_escape
) )
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") if message.message_info.group_info:
except Exception as e: try:
print(f"发生错误 {e}") await self._current_bot.send_group_msg(
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败") group_id=message.message_info.group_info.group_id,
message=message_send.raw_message,
auto_escape=False
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
else:
try:
await self._current_bot.send_private_msg(
user_id=message.message_info.user_info.user_id,
message=message_send.raw_message,
auto_escape=False
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer: class MessageContainer:
"""单个的发送/思考消息容器""" """单个聊天流的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100): def __init__(self, chat_id: str, max_size: int = 100):
self.group_id = group_id self.chat_id = chat_id
self.max_size = max_size self.max_size = max_size
self.messages = [] self.messages = []
self.last_send_time = 0 self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒) self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]: def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序""" """获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
for msg in self.messages: for msg in self.messages:
if isinstance(msg, Message_Sending): if isinstance(msg, MessageSending):
if current_time - msg.thinking_start_time > self.thinking_timeout: if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg) timeout_messages.append(msg)
@@ -89,7 +83,7 @@ class MessageContainer:
return timeout_messages return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]: def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
"""获取thinking_start_time最早的消息对象""" """获取thinking_start_time最早的消息对象"""
if not self.messages: if not self.messages:
return None return None
@@ -102,16 +96,15 @@ class MessageContainer:
earliest_message = msg earliest_message = msg
return earliest_message return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
"""添加消息到队列""" """添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
if isinstance(message, MessageSet): if isinstance(message, MessageSet):
for single_message in message.messages: for single_message in message.messages:
self.messages.append(single_message) self.messages.append(single_message)
else: else:
self.messages.append(message) self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool: def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False""" """移除消息如果消息存在则返回True否则返回False"""
try: try:
if message in self.messages: if message in self.messages:
@@ -126,40 +119,42 @@ class MessageContainer:
"""检查是否有待发送的消息""" """检查是否有待发送的消息"""
return bool(self.messages) return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]: def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
"""获取所有消息""" """获取所有消息"""
return list(self.messages) return list(self.messages)
class MessageManager: class MessageManager:
"""管理所有的消息容器""" """管理所有聊天流的消息容器"""
def __init__(self): def __init__(self):
self.containers: Dict[int, MessageContainer] = {} self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
self.storage = MessageStorage() self.storage = MessageStorage()
self._running = True self._running = True
def get_container(self, group_id: int) -> MessageContainer: def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建的消息容器""" """获取或创建聊天流的消息容器"""
if group_id not in self.containers: if chat_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id) self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[group_id] return self.containers[chat_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
container = self.get_container(message.group_id) chat_stream = chat_manager.get_stream_by_info(
platform=message.message_info.platform,
user_info=message.message_info.user_info,
group_info=message.message_info.group_info
)
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
container.add_message(message) container.add_message(message)
async def process_group_messages(self, group_id: int): async def process_chat_messages(self, chat_id: str):
"""处理消息""" """处理聊天流消息"""
# if int(time.time() / 3) == time.time() / 3: container = self.get_container(chat_id)
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id)
if container.has_messages(): if container.has_messages():
#最早的对象,可能是思考消息,也可能是发送消息 message_earliest = container.get_earliest_message()
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending
#如果是思考消息 if isinstance(message_earliest, MessageThinking):
if isinstance(message_earliest, Message_Thinking):
#优先等待这条消息
message_earliest.update_thinking_time() message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time thinking_time = message_earliest.thinking_time
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}\033[K\r", end='', flush=True) print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}\033[K\r", end='', flush=True)
@@ -168,42 +163,36 @@ class MessageManager:
if thinking_time > global_config.thinking_timeout: if thinking_time > global_config.thinking_timeout:
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息") print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest) container.remove_message(message_earliest)
else:# 如果不是message_thinking就只能是message_sending else:
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
#直接发,等什么呢 if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
if message_earliest.is_head and message_earliest.update_thinking_time() >30: await message_sender.send_message(message_earliest)
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) await message_sender.send_message(message_earliest)
#移除消息
if message_earliest.is_emoji: if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]" message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None) await self.storage.store_message(message_earliest, None)
container.remove_message(message_earliest) container.remove_message(message_earliest)
#获取并处理超时消息 message_timeout = container.get_timeout_messages()
message_timeout = container.get_timeout_messages() #也许是一堆message_sending
if message_timeout: if message_timeout:
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息") print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息")
for msg in message_timeout: for msg in message_timeout:
if msg == message_earliest: if msg == message_earliest:
continue # 跳过已经处理过的消息 continue
try: try:
#发送 if msg.is_head and msg.update_thinking_time() > 30:
if msg.is_head and msg.update_thinking_time() >30: await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False)
#如果是表情包,则替换为"[表情包]"
if msg.is_emoji: if msg.is_emoji:
msg.processed_plain_text = "[表情包]" msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None) await self.storage.store_message(msg, None)
# 安全地移除消息
if not container.remove_message(msg): if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
except Exception as e: except Exception as e:
@@ -215,8 +204,8 @@ class MessageManager:
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
tasks = [] tasks = []
for group_id in self.containers.keys(): for chat_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id)) tasks.append(self.process_chat_messages(chat_id))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -1,8 +1,9 @@
import asyncio import asyncio
from typing import Optional from typing import Optional, Union
from ...common.database import Database from ...common.database import Database
from .message_base import UserInfo
from .chat_stream import ChatStream
class Impression: class Impression:
traits: str = None traits: str = None
@@ -13,60 +14,77 @@ class Impression:
class Relationship: class Relationship:
user_id: int = None user_id: int = None
# impression: Impression = None platform: str = None
# group_id: int = None
# group_name: str = None
gender: str = None gender: str = None
age: int = None age: int = None
nickname: str = None nickname: str = None
relationship_value: float = None relationship_value: float = None
saved = False saved = False
def __init__(self, user_id: int, data=None, **kwargs): def __init__(self, chat:ChatStream,data:dict):
if isinstance(data, dict): self.user_id=chat.user_info.user_id
# 如果输入是字典,使用字典解析 self.platform=chat.platform
self.user_id = data.get('user_id') self.nickname=chat.user_info.user_nickname
self.gender = data.get('gender') self.relationship_value=data.get('relationship_value',0)
self.age = data.get('age') self.age=data.get('age',0)
self.nickname = data.get('nickname') self.gender=data.get('gender','')
self.relationship_value = data.get('relationship_value', 0.0)
self.saved = data.get('saved', False)
else:
# 如果是直接传入属性值
self.user_id = kwargs.get('user_id')
self.gender = kwargs.get('gender')
self.age = kwargs.get('age')
self.nickname = kwargs.get('nickname')
self.relationship_value = kwargs.get('relationship_value', 0.0)
self.saved = kwargs.get('saved', False)
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationships: dict[int, Relationship] = {} self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self,
chat_stream:ChatStream,
data: dict = None,
**kwargs) -> Optional[Relationship]:
"""更新或创建关系
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
data: 字典格式的数据(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if chat_stream.user_info is not None:
user_id = chat_stream.user_info.user_id
platform = chat_stream.user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship(self, user_id: int, data=None, **kwargs):
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(key)
if relationship: if relationship:
# 如果存在,更新现有对象 # 如果存在,更新现有对象
if isinstance(data, dict): if isinstance(data, dict):
for key, value in data.items(): for k, value in data.items():
if hasattr(relationship, key) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, key, value) setattr(relationship, k, value)
else: else:
for key, value in kwargs.items(): for k, value in kwargs.items():
if hasattr(relationship, key) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, key, value) setattr(relationship, k, value)
else: else:
# 如果不存在,创建新对象 # 如果不存在,创建新对象
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs) if user_info is not None:
self.relationships[user_id] = relationship relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
# 更新 id_name_nickname_table data['platform'] = platform
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表 relationship = Relationship(user_id=user_id, data=data)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
# 保存到数据库 # 保存到数据库
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
@@ -74,33 +92,87 @@ class RelationshipManager:
return relationship return relationship
async def update_relationship_value(self, user_id: int, **kwargs): async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(key)
if relationship: if relationship:
for key, value in kwargs.items(): for k, value in kwargs.items():
if key == 'relationship_value': if k == 'relationship_value':
relationship.relationship_value += value relationship.relationship_value += value
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
relationship.saved = True relationship.saved = True
return relationship return relationship
else: else:
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新") # 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
return None return None
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
def get_relationship(self, user_id: int) -> Optional[Relationship]: if user_id is None:
"""获取用户关系对象""" raise ValueError("必须提供user_id或user_info")
if user_id in self.relationships:
return self.relationships[user_id] key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
else: else:
return 0 return 0
async def load_relationship(self, data: dict) -> Relationship: async def load_relationship(self, data: dict) -> Relationship:
"""从数据库加载或创建新的关系对象""" """从数据库加载或创建新的关系对象"""
rela = Relationship(user_id=data['user_id'], data=data) # 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True rela.saved = True
self.relationships[rela.user_id] = rela key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela return rela
async def load_all_relationships(self): async def load_all_relationships(self):
@@ -117,9 +189,7 @@ class RelationshipManager:
all_relationships = db.db.relationships.find({}) all_relationships = db.db.relationships.find({})
# 依次加载每条记录 # 依次加载每条记录
for data in all_relationships: for data in all_relationships:
user_id = data['user_id'] await self.load_relationship(data)
relationship = await self.load_relationship(data)
self.relationships[user_id] = relationship
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
while True: while True:
@@ -130,16 +200,15 @@ class RelationshipManager:
async def _save_all_relationships(self): async def _save_all_relationships(self):
"""将所有关系数据保存到数据库""" """将所有关系数据保存到数据库"""
# 保存所有关系数据 # 保存所有关系数据
for userid, relationship in self.relationships.items(): for (userid, platform), relationship in self.relationships.items():
if not relationship.saved: if not relationship.saved:
relationship.saved = True relationship.saved = True
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
async def storage_relationship(self,relationship: Relationship): async def storage_relationship(self, relationship: Relationship):
""" """将关系记录存储到数据库中"""
将关系记录存储到数据库中
"""
user_id = relationship.user_id user_id = relationship.user_id
platform = relationship.platform
nickname = relationship.nickname nickname = relationship.nickname
relationship_value = relationship.relationship_value relationship_value = relationship.relationship_value
gender = relationship.gender gender = relationship.gender
@@ -148,8 +217,9 @@ class RelationshipManager:
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.db.relationships.update_one(
{'user_id': user_id}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {
'platform': platform,
'nickname': nickname, 'nickname': nickname,
'relationship_value': relationship_value, 'relationship_value': relationship_value,
'gender': gender, 'gender': gender,
@@ -159,12 +229,35 @@ class RelationshipManager:
upsert=True upsert=True
) )
def get_name(self, user_id: int) -> str: def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 确保user_id是整数类型 # 确保user_id是整数类型
user_id = int(user_id) user_id = int(user_id)
if user_id in self.relationships: key = (user_id, platform)
if key in self.relationships:
return self.relationships[user_id].nickname return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
else: else:
return "某人" return "某人"

View File

@@ -1,47 +1,26 @@
from typing import Optional from typing import Optional, Union
from ...common.database import Database from ...common.database import Database
from .message_cq import Message from .message_base import MessageBase
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
class MessageStorage: class MessageStorage:
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
if not message.is_emoji: message_data = {
message_data = { "message_id": message.message_info.message_id,
"group_id": message.group_id, "time": message.message_info.time,
"user_id": message.user_id, "chat_id":chat_stream.stream_id,
"message_id": message.message_id, "chat_info": chat_stream.to_dict(),
"raw_message": message.raw_message, "detailed_plain_text": message.detailed_plain_text,
"plain_text": message.plain_text,
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic, "topic": topic,
"detailed_plain_text": message.detailed_plain_text,
} }
else:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
"processed_plain_text": '[表情包]',
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text,
}
self.db.db.messages.insert_one(message_data) self.db.db.messages.insert_one(message_data)
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}") print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")

View File

@@ -1,10 +1,15 @@
import asyncio import asyncio
from typing import Dict
from loguru import logger
from .config import global_config from .config import global_config
from .message_base import UserInfo, GroupInfo
from .chat_stream import chat_manager,ChatStream
class WillingManager: class WillingManager:
def __init__(self): def __init__(self):
self.group_reply_willing = {} # 存储每个的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None self._decay_task = None
self._started = False self._started = False
@@ -12,20 +17,33 @@ class WillingManager:
"""定期衰减回复意愿""" """定期衰减回复意愿"""
while True: while True:
await asyncio.sleep(5) await asyncio.sleep(5)
for group_id in self.group_reply_willing: for chat_id in self.chat_reply_willing:
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6) self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self, group_id: int) -> float: def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定群组的回复意愿""" """获取指定聊天流的回复意愿"""
return self.group_reply_willing.get(group_id, 0) stream = chat_stream
if stream:
return self.chat_reply_willing.get(stream.stream_id, 0)
return 0
def set_willing(self, group_id: int, willing: float): def set_willing(self, chat_id: str, willing: float):
"""设置指定群组的回复意愿""" """设置指定聊天流的回复意愿"""
self.group_reply_willing[group_id] = willing self.chat_reply_willing[chat_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float: async def change_reply_willing_received(self,
"""改变指定群组的回复意愿并返回回复概率""" chat_stream:ChatStream,
current_willing = self.group_reply_willing.get(group_id, 0) topic: str = None,
is_mentioned_bot: bool = False,
config = None,
is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流
stream = chat_stream
chat_id = stream.stream_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
# print(f"初始意愿: {current_willing}") # print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0: if is_mentioned_bot and current_willing < 1.0:
@@ -49,31 +67,37 @@ class WillingManager:
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}") # print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
reply_probability = max((current_willing - 0.45) * 2, 0) reply_probability = max((current_willing - 0.45) * 2, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
if group_id in config.talk_frequency_down_groups: # 检查群组权限(如果是群聊)
reply_probability = reply_probability / global_config.down_frequency_rate if chat_stream.group_info:
if chat_stream.group_info.group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
reply_probability = min(reply_probability, 1) reply_probability = min(reply_probability, 1)
if reply_probability < 0: if reply_probability < 0:
reply_probability = 0 reply_probability = 0
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
self.group_reply_willing[group_id] = min(current_willing, 3.0)
return reply_probability return reply_probability
def change_reply_willing_sent(self, group_id: int): def change_reply_willing_sent(self, chat_stream:ChatStream):
"""开始思考后降低群组的回复意愿""" """开始思考后降低聊天流的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) stream = chat_stream
self.group_reply_willing[group_id] = max(0, current_willing - 2) if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int): def change_reply_willing_after_sent(self,chat_stream:ChatStream):
"""发送消息后提高群组的回复意愿""" """发送消息后提高聊天流的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) stream = chat_stream
if current_willing < 1: if stream:
self.group_reply_willing[group_id] = min(1, current_willing + 0.2) current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
if current_willing < 1:
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
async def ensure_started(self): async def ensure_started(self):
"""确保衰减任务已启动""" """确保衰减任务已启动"""