Merge branch 'refractor' of https://github.com/tcmofashi/MaiMBot into refractor

This commit is contained in:
tcmofashi
2025-03-10 21:01:06 +08:00
4 changed files with 183 additions and 0 deletions

View File

@@ -11,18 +11,25 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import ( from .message_cq import (
MessageRecvCQ, MessageRecvCQ,
MessageSendCQ, MessageSendCQ,
) )
from .chat_stream import chat_manager
MessageRecvCQ,
MessageSendCQ,
)
from .chat_stream import chat_manager from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils_image import image_path_to_base64 from .utils_image import image_path_to_base64
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot: class ChatBot:
def __init__(self): def __init__(self):
@@ -48,17 +55,22 @@ class ChatBot:
group_info = await bot.get_group_info(group_id=event.group_id) group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
message_cq=MessageRecvCQ(
message_cq=MessageRecvCQ( message_cq=MessageRecvCQ(
message_id=event.message_id, message_id=event.message_id,
user_id=event.user_id, user_id=event.user_id,
raw_message=str(event.original_message), raw_message=str(event.original_message),
group_id=event.group_id, group_id=event.group_id,
user_id=event.user_id,
raw_message=str(event.original_message),
group_id=event.group_id,
reply_message=event.reply, reply_message=event.reply,
platform='qq' platform='qq'
) )
@@ -88,12 +100,15 @@ class ChatBot:
return return
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
if word in message.processed_plain_text: if word in message.processed_plain_text:
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}") logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered") logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
@@ -124,6 +139,11 @@ class ChatBot:
response = None response = None
if random() < reply_probability: if random() < reply_probability:
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform
)
bot_user_info=UserInfo( bot_user_info=UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.BOT_NICKNAME,
@@ -136,6 +156,11 @@ class ChatBot:
message_id=think_id, message_id=think_id,
reply=message reply=message
) )
thinking_message = MessageThinking.from_chat_stream(
chat_stream=chat,
message_id=think_id,
reply=message
)
message_manager.add_message(thinking_message) message_manager.add_message(thinking_message)
@@ -146,6 +171,7 @@ class ChatBot:
response,raw_content = await self.gpt.generate_response(message) response,raw_content = await self.gpt.generate_response(message)
if response: if response:
container = message_manager.get_container(chat.stream_id)
container = message_manager.get_container(chat.stream_id) container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
# 找到message,删除 # 找到message,删除
@@ -163,6 +189,7 @@ class ChatBot:
#记录开始思考的时间,避免从思考到回复的时间太久 #记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(chat, think_id) message_set = MessageSet(chat, think_id)
message_set = MessageSet(chat, think_id)
#计算打字时间1是为了模拟打字2是避免多条回复乱序 #计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0 accu_typing_time = 0
@@ -174,6 +201,8 @@ class ChatBot:
accu_typing_time += typing_time accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time timepoint = tinking_time_point + accu_typing_time
message_segment = Seg(type='text', data=msg)
bot_message = MessageSending(
message_segment = Seg(type='text', data=msg) message_segment = Seg(type='text', data=msg)
bot_message = MessageSending( bot_message = MessageSending(
message_id=think_id, message_id=think_id,
@@ -182,6 +211,12 @@ class ChatBot:
reply=message, reply=message,
is_head=not mark_head, is_head=not mark_head,
is_emoji=False is_emoji=False
)
chat_stream=chat,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
is_emoji=False
) )
if not mark_head: if not mark_head:
mark_head = True mark_head = True
@@ -200,6 +235,7 @@ class ChatBot:
if emoji_raw != None: if emoji_raw != None:
emoji_path,discription = emoji_raw emoji_path,discription = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
emoji_cq = image_path_to_base64(emoji_path) emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5: if random() < 0.5:
@@ -207,6 +243,15 @@ class ChatBot:
else: else:
bot_response_time = bot_response_time + 1 bot_response_time = bot_response_time + 1
message_segment = Seg(type='emoji', data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
message_segment=message_segment,
reply=message,
is_head=False,
is_emoji=True
)
message_segment = Seg(type='emoji', data=emoji_cq) message_segment = Seg(type='emoji', data=emoji_cq)
bot_message = MessageSending( bot_message = MessageSending(
message_id=think_id, message_id=think_id,
@@ -238,5 +283,11 @@ class ChatBot:
group_info=groupinfo group_info=groupinfo
) )
willing_manager.change_reply_willing_after_sent(
platform=messageinfo.platform,
user_info=userinfo,
group_info=groupinfo
)
# 创建全局ChatBot实例 # 创建全局ChatBot实例
chat_bot = ChatBot() chat_bot = ChatBot()

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
from typing import Optional, Union from typing import Optional, Union
from typing import Optional, Union
from ...common.database import Database from ...common.database import Database
from .message_base import UserInfo from .message_base import UserInfo
@@ -15,6 +16,7 @@ class Impression:
class Relationship: class Relationship:
user_id: int = None user_id: int = None
platform: str = None platform: str = None
platform: str = None
gender: str = None gender: str = None
age: int = None age: int = None
nickname: str = None nickname: str = None
@@ -33,6 +35,7 @@ class Relationship:
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键 self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self, async def update_relationship(self,
chat_stream:ChatStream, chat_stream:ChatStream,
@@ -63,16 +66,23 @@ class RelationshipManager:
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(key) relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship: if relationship:
# 如果存在,更新现有对象 # 如果存在,更新现有对象
if isinstance(data, dict): if isinstance(data, dict):
for k, value in data.items(): for k, value in data.items():
if hasattr(relationship, k) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value) setattr(relationship, k, value)
for k, value in data.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else: else:
for k, value in kwargs.items(): for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value) setattr(relationship, k, value)
for k, value in kwargs.items():
if hasattr(relationship, k) and value is not None:
setattr(relationship, k, value)
else: else:
# 如果不存在,创建新对象 # 如果不存在,创建新对象
if user_info is not None: if user_info is not None:
@@ -85,6 +95,16 @@ class RelationshipManager:
kwargs['user_id'] = user_id kwargs['user_id'] = user_id
relationship = Relationship(**kwargs) relationship = Relationship(**kwargs)
self.relationships[key] = relationship self.relationships[key] = relationship
if user_info is not None:
relationship = Relationship(user_info=user_info, **kwargs)
elif isinstance(data, dict):
data['platform'] = platform
relationship = Relationship(user_id=user_id, data=data)
else:
kwargs['platform'] = platform
kwargs['user_id'] = user_id
relationship = Relationship(**kwargs)
self.relationships[key] = relationship
# 保存到数据库 # 保存到数据库
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
@@ -92,6 +112,33 @@ class RelationshipManager:
return relationship return relationship
async def update_relationship_value(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship_value(self, async def update_relationship_value(self,
user_id: int = None, user_id: int = None,
platform: str = None, platform: str = None,
@@ -121,7 +168,10 @@ class RelationshipManager:
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(key) relationship = self.relationships.get(key)
relationship = self.relationships.get(key)
if relationship: if relationship:
for k, value in kwargs.items():
if k == 'relationship_value':
for k, value in kwargs.items(): for k, value in kwargs.items():
if k == 'relationship_value': if k == 'relationship_value':
relationship.relationship_value += value relationship.relationship_value += value
@@ -129,12 +179,41 @@ class RelationshipManager:
relationship.saved = True relationship.saved = True
return relationship return relationship
else: else:
# 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
# 如果不存在且提供了user_info则创建新的关系 # 如果不存在且提供了user_info则创建新的关系
if user_info is not None: if user_info is not None:
return await self.update_relationship(user_info=user_info, **kwargs) return await self.update_relationship(user_info=user_info, **kwargs)
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新") print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
return None return None
def get_relationship(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> Optional[Relationship]:
"""获取用户关系对象
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
def get_relationship(self, def get_relationship(self,
user_id: int = None, user_id: int = None,
platform: str = None, platform: str = None,
@@ -169,10 +248,18 @@ class RelationshipManager:
if 'platform' not in data: if 'platform' not in data:
data['platform'] = 'qq' data['platform'] = 'qq'
rela = Relationship(data=data)
"""从数据库加载或创建新的关系对象"""
# 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data) rela = Relationship(data=data)
rela.saved = True rela.saved = True
key = (rela.user_id, rela.platform) key = (rela.user_id, rela.platform)
self.relationships[key] = rela self.relationships[key] = rela
key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela return rela
async def load_all_relationships(self): async def load_all_relationships(self):
@@ -190,6 +277,7 @@ class RelationshipManager:
# 依次加载每条记录 # 依次加载每条记录
for data in all_relationships: for data in all_relationships:
await self.load_relationship(data) await self.load_relationship(data)
await self.load_relationship(data)
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
while True: while True:
@@ -200,15 +288,19 @@ class RelationshipManager:
async def _save_all_relationships(self): async def _save_all_relationships(self):
"""将所有关系数据保存到数据库""" """将所有关系数据保存到数据库"""
# 保存所有关系数据 # 保存所有关系数据
for (userid, platform), relationship in self.relationships.items():
for (userid, platform), relationship in self.relationships.items(): for (userid, platform), relationship in self.relationships.items():
if not relationship.saved: if not relationship.saved:
relationship.saved = True relationship.saved = True
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
async def storage_relationship(self, relationship: Relationship): async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中""" """将关系记录存储到数据库中"""
user_id = relationship.user_id user_id = relationship.user_id
platform = relationship.platform platform = relationship.platform
platform = relationship.platform
nickname = relationship.nickname nickname = relationship.nickname
relationship_value = relationship.relationship_value relationship_value = relationship.relationship_value
gender = relationship.gender gender = relationship.gender
@@ -217,8 +309,10 @@ class RelationshipManager:
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'user_id': user_id, 'platform': platform}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {
'platform': platform,
'platform': platform, 'platform': platform,
'nickname': nickname, 'nickname': nickname,
'relationship_value': relationship_value, 'relationship_value': relationship_value,
@@ -229,6 +323,28 @@ class RelationshipManager:
upsert=True upsert=True
) )
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
def get_name(self, def get_name(self,
user_id: int = None, user_id: int = None,
platform: str = None, platform: str = None,
@@ -254,6 +370,11 @@ class RelationshipManager:
# 确保user_id是整数类型 # 确保user_id是整数类型
user_id = int(user_id) user_id = int(user_id)
key = (user_id, platform) key = (user_id, platform)
if key in self.relationships:
return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
key = (user_id, platform)
if key in self.relationships: if key in self.relationships:
return self.relationships[key].nickname return self.relationships[key].nickname
elif user_info is not None: elif user_info is not None:

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union from typing import Optional, Union
from typing import Optional, Union
from ...common.database import Database from ...common.database import Database
from .message_base import MessageBase from .message_base import MessageBase

View File

@@ -2,6 +2,9 @@ import asyncio
from typing import Dict from typing import Dict
from loguru import logger from loguru import logger
from typing import Dict
from loguru import logger
from .config import global_config from .config import global_config
from .message_base import UserInfo, GroupInfo from .message_base import UserInfo, GroupInfo
from .chat_stream import chat_manager,ChatStream from .chat_stream import chat_manager,ChatStream
@@ -9,6 +12,7 @@ from .chat_stream import chat_manager,ChatStream
class WillingManager: class WillingManager:
def __init__(self): def __init__(self):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None self._decay_task = None
self._started = False self._started = False
@@ -19,6 +23,8 @@ class WillingManager:
await asyncio.sleep(5) await asyncio.sleep(5)
for chat_id in self.chat_reply_willing: for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self,chat_stream:ChatStream) -> float: def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定聊天流的回复意愿""" """获取指定聊天流的回复意愿"""
@@ -27,6 +33,9 @@ class WillingManager:
return self.chat_reply_willing.get(stream.stream_id, 0) return self.chat_reply_willing.get(stream.stream_id, 0)
return 0 return 0
def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing
def set_willing(self, chat_id: str, willing: float): def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿""" """设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing self.chat_reply_willing[chat_id] = willing
@@ -81,6 +90,7 @@ class WillingManager:
if reply_probability < 0: if reply_probability < 0:
reply_probability = 0 reply_probability = 0
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
return reply_probability return reply_probability