Merge branch 'refractor' of https://github.com/tcmofashi/MaiMBot into refractor
This commit is contained in:
@@ -11,18 +11,25 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||
from .llm_generator import ResponseGenerator
|
||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from .message_cq import (
|
||||
MessageRecvCQ,
|
||||
MessageSendCQ,
|
||||
)
|
||||
from .chat_stream import chat_manager
|
||||
MessageRecvCQ,
|
||||
MessageSendCQ,
|
||||
)
|
||||
from .chat_stream import chat_manager
|
||||
from .message_sender import message_manager # 导入新的消息管理器
|
||||
from .relationship_manager import relationship_manager
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||
from .utils_image import image_path_to_base64
|
||||
from .utils_image import image_path_to_base64
|
||||
from .willing_manager import willing_manager # 导入意愿管理器
|
||||
from .message_base import UserInfo, GroupInfo, Seg
|
||||
from .message_base import UserInfo, GroupInfo, Seg
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
@@ -47,6 +54,7 @@ class ChatBot:
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
|
||||
|
||||
|
||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
@@ -54,11 +62,15 @@ class ChatBot:
|
||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||
|
||||
message_cq=MessageRecvCQ(
|
||||
message_cq=MessageRecvCQ(
|
||||
message_id=event.message_id,
|
||||
user_id=event.user_id,
|
||||
raw_message=str(event.original_message),
|
||||
group_id=event.group_id,
|
||||
user_id=event.user_id,
|
||||
raw_message=str(event.original_message),
|
||||
group_id=event.group_id,
|
||||
reply_message=event.reply,
|
||||
platform='qq'
|
||||
)
|
||||
@@ -88,12 +100,15 @@ class ChatBot:
|
||||
return
|
||||
# 过滤词
|
||||
for word in global_config.ban_words:
|
||||
if word in message.processed_plain_text:
|
||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||
if word in message.processed_plain_text:
|
||||
logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}")
|
||||
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered")
|
||||
return
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||
|
||||
|
||||
|
||||
@@ -124,6 +139,11 @@ class ChatBot:
|
||||
response = None
|
||||
|
||||
if random() < reply_probability:
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=messageinfo.platform
|
||||
)
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
@@ -136,6 +156,11 @@ class ChatBot:
|
||||
message_id=think_id,
|
||||
reply=message
|
||||
)
|
||||
thinking_message = MessageThinking.from_chat_stream(
|
||||
chat_stream=chat,
|
||||
message_id=think_id,
|
||||
reply=message
|
||||
)
|
||||
|
||||
message_manager.add_message(thinking_message)
|
||||
|
||||
@@ -146,6 +171,7 @@ class ChatBot:
|
||||
response,raw_content = await self.gpt.generate_response(message)
|
||||
|
||||
if response:
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
thinking_message = None
|
||||
# 找到message,删除
|
||||
@@ -163,6 +189,7 @@ class ChatBot:
|
||||
#记录开始思考的时间,避免从思考到回复的时间太久
|
||||
thinking_start_time = thinking_message.thinking_start_time
|
||||
message_set = MessageSet(chat, think_id)
|
||||
message_set = MessageSet(chat, think_id)
|
||||
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||
accu_typing_time = 0
|
||||
|
||||
@@ -174,6 +201,8 @@ class ChatBot:
|
||||
accu_typing_time += typing_time
|
||||
timepoint = tinking_time_point + accu_typing_time
|
||||
|
||||
message_segment = Seg(type='text', data=msg)
|
||||
bot_message = MessageSending(
|
||||
message_segment = Seg(type='text', data=msg)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
@@ -182,6 +211,12 @@ class ChatBot:
|
||||
reply=message,
|
||||
is_head=not mark_head,
|
||||
is_emoji=False
|
||||
)
|
||||
chat_stream=chat,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=not mark_head,
|
||||
is_emoji=False
|
||||
)
|
||||
if not mark_head:
|
||||
mark_head = True
|
||||
@@ -200,6 +235,7 @@ class ChatBot:
|
||||
if emoji_raw != None:
|
||||
emoji_path,discription = emoji_raw
|
||||
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
|
||||
if random() < 0.5:
|
||||
@@ -207,6 +243,15 @@ class ChatBot:
|
||||
else:
|
||||
bot_response_time = bot_response_time + 1
|
||||
|
||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
chat_stream=chat,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=False,
|
||||
is_emoji=True
|
||||
)
|
||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
@@ -237,6 +282,12 @@ class ChatBot:
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo
|
||||
)
|
||||
|
||||
willing_manager.change_reply_willing_after_sent(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo
|
||||
)
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
chat_bot = ChatBot()
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import UserInfo
|
||||
@@ -15,6 +16,7 @@ class Impression:
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
platform: str = None
|
||||
platform: str = None
|
||||
gender: str = None
|
||||
age: int = None
|
||||
nickname: str = None
|
||||
@@ -33,6 +35,7 @@ class Relationship:
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
|
||||
async def update_relationship(self,
|
||||
chat_stream:ChatStream,
|
||||
@@ -63,16 +66,23 @@ class RelationshipManager:
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(key)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
# 如果存在,更新现有对象
|
||||
if isinstance(data, dict):
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
for k, value in kwargs.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
for k, value in kwargs.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
# 如果不存在,创建新对象
|
||||
if user_info is not None:
|
||||
@@ -85,6 +95,16 @@ class RelationshipManager:
|
||||
kwargs['user_id'] = user_id
|
||||
relationship = Relationship(**kwargs)
|
||||
self.relationships[key] = relationship
|
||||
if user_info is not None:
|
||||
relationship = Relationship(user_info=user_info, **kwargs)
|
||||
elif isinstance(data, dict):
|
||||
data['platform'] = platform
|
||||
relationship = Relationship(user_id=user_id, data=data)
|
||||
else:
|
||||
kwargs['platform'] = platform
|
||||
kwargs['user_id'] = user_id
|
||||
relationship = Relationship(**kwargs)
|
||||
self.relationships[key] = relationship
|
||||
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
@@ -92,6 +112,33 @@ class RelationshipManager:
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新关系值
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
async def update_relationship_value(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -121,7 +168,10 @@ class RelationshipManager:
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(key)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
relationship.relationship_value += value
|
||||
@@ -129,12 +179,41 @@ class RelationshipManager:
|
||||
relationship.saved = True
|
||||
return relationship
|
||||
else:
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(user_info=user_info, **kwargs)
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
return None
|
||||
|
||||
def get_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> Optional[Relationship]:
|
||||
"""获取用户关系对象
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key]
|
||||
def get_relationship(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -169,10 +248,18 @@ class RelationshipManager:
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
rela.saved = True
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
return rela
|
||||
|
||||
async def load_all_relationships(self):
|
||||
@@ -190,6 +277,7 @@ class RelationshipManager:
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
await self.load_relationship(data)
|
||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
@@ -200,15 +288,19 @@ class RelationshipManager:
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
if not relationship.saved:
|
||||
relationship.saved = True
|
||||
await self.storage_relationship(relationship)
|
||||
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""将关系记录存储到数据库中"""
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""将关系记录存储到数据库中"""
|
||||
user_id = relationship.user_id
|
||||
platform = relationship.platform
|
||||
platform = relationship.platform
|
||||
nickname = relationship.nickname
|
||||
relationship_value = relationship.relationship_value
|
||||
gender = relationship.gender
|
||||
@@ -217,8 +309,10 @@ class RelationshipManager:
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'$set': {
|
||||
'platform': platform,
|
||||
'platform': platform,
|
||||
'nickname': nickname,
|
||||
'relationship_value': relationship_value,
|
||||
@@ -229,6 +323,28 @@ class RelationshipManager:
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> str:
|
||||
"""获取用户昵称
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
str: 用户昵称
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
@@ -254,6 +370,11 @@ class RelationshipManager:
|
||||
# 确保user_id是整数类型
|
||||
user_id = int(user_id)
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import MessageBase
|
||||
|
||||
@@ -2,6 +2,9 @@ import asyncio
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from .config import global_config
|
||||
from .message_base import UserInfo, GroupInfo
|
||||
from .chat_stream import chat_manager,ChatStream
|
||||
@@ -9,6 +12,7 @@ from .chat_stream import chat_manager,ChatStream
|
||||
|
||||
class WillingManager:
|
||||
def __init__(self):
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self._decay_task = None
|
||||
self._started = False
|
||||
@@ -19,6 +23,8 @@ class WillingManager:
|
||||
await asyncio.sleep(5)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
|
||||
def get_willing(self,chat_stream:ChatStream) -> float:
|
||||
"""获取指定聊天流的回复意愿"""
|
||||
@@ -27,6 +33,9 @@ class WillingManager:
|
||||
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
return 0
|
||||
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
@@ -81,6 +90,7 @@ class WillingManager:
|
||||
if reply_probability < 0:
|
||||
reply_probability = 0
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
return reply_probability
|
||||
|
||||
|
||||
Reference in New Issue
Block a user