diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 5385f3afb..17aa0630d 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -11,18 +11,25 @@ from .cq_code import CQCode,cq_code_tool # 导入CQCode模块 from .emoji_manager import emoji_manager # 导入表情包管理器 from .llm_generator import ResponseGenerator from .message import MessageSending, MessageRecv, MessageThinking, MessageSet +from .message import MessageSending, MessageRecv, MessageThinking, MessageSet from .message_cq import ( MessageRecvCQ, MessageSendCQ, ) +from .chat_stream import chat_manager + MessageRecvCQ, + MessageSendCQ, +) from .chat_stream import chat_manager from .message_sender import message_manager # 导入新的消息管理器 from .relationship_manager import relationship_manager from .storage import MessageStorage from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils_image import image_path_to_base64 +from .utils_image import image_path_to_base64 from .willing_manager import willing_manager # 导入意愿管理器 from .message_base import UserInfo, GroupInfo, Seg +from .message_base import UserInfo, GroupInfo, Seg class ChatBot: def __init__(self): @@ -47,6 +54,7 @@ class ChatBot: self.bot = bot # 更新 bot 实例 + group_info = await bot.get_group_info(group_id=event.group_id) sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) @@ -54,11 +62,15 @@ class ChatBot: await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) + message_cq=MessageRecvCQ( message_cq=MessageRecvCQ( message_id=event.message_id, user_id=event.user_id, raw_message=str(event.original_message), group_id=event.group_id, + user_id=event.user_id, + raw_message=str(event.original_message), + group_id=event.group_id, reply_message=event.reply, platform='qq' ) @@ -88,12 +100,15 @@ class ChatBot: return # 过滤词 for word in global_config.ban_words: + if word in message.processed_plain_text: + logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}") if word in message.processed_plain_text: logger.info(f"\033[1;32m[{groupinfo.group_name}]{userinfo.user_nickname}:\033[0m {message.processed_plain_text}") logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word},filtered") return current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time)) + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time)) @@ -124,6 +139,11 @@ class ChatBot: response = None if random() < reply_probability: + bot_user_info=UserInfo( + user_id=global_config.BOT_QQ, + user_nickname=global_config.BOT_NICKNAME, + platform=messageinfo.platform + ) bot_user_info=UserInfo( user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, @@ -136,6 +156,11 @@ class ChatBot: message_id=think_id, reply=message ) + thinking_message = MessageThinking.from_chat_stream( + chat_stream=chat, + message_id=think_id, + reply=message + ) message_manager.add_message(thinking_message) @@ -146,6 +171,7 @@ class ChatBot: response,raw_content = await self.gpt.generate_response(message) if response: + container = message_manager.get_container(chat.stream_id) container = message_manager.get_container(chat.stream_id) thinking_message = None # 找到message,删除 @@ -163,6 +189,7 @@ class ChatBot: #记录开始思考的时间,避免从思考到回复的时间太久 thinking_start_time = thinking_message.thinking_start_time message_set = MessageSet(chat, think_id) + message_set = MessageSet(chat, think_id) #计算打字时间,1是为了模拟打字,2是避免多条回复乱序 accu_typing_time = 0 @@ -174,6 +201,8 @@ class ChatBot: accu_typing_time += typing_time timepoint = tinking_time_point + accu_typing_time + message_segment = Seg(type='text', data=msg) + bot_message = MessageSending( message_segment = Seg(type='text', data=msg) bot_message = MessageSending( message_id=think_id, @@ -182,6 +211,12 @@ class ChatBot: reply=message, is_head=not mark_head, is_emoji=False + ) + chat_stream=chat, + message_segment=message_segment, + reply=message, + is_head=not mark_head, + is_emoji=False ) if not mark_head: mark_head = True @@ -200,6 +235,7 @@ class ChatBot: if emoji_raw != None: emoji_path,discription = emoji_raw + emoji_cq = image_path_to_base64(emoji_path) emoji_cq = image_path_to_base64(emoji_path) if random() < 0.5: @@ -207,6 +243,15 @@ class ChatBot: else: bot_response_time = bot_response_time + 1 + message_segment = Seg(type='emoji', data=emoji_cq) + bot_message = MessageSending( + message_id=think_id, + chat_stream=chat, + message_segment=message_segment, + reply=message, + is_head=False, + is_emoji=True + ) message_segment = Seg(type='emoji', data=emoji_cq) bot_message = MessageSending( message_id=think_id, @@ -237,6 +282,12 @@ class ChatBot: user_info=userinfo, group_info=groupinfo ) + + willing_manager.change_reply_willing_after_sent( + platform=messageinfo.platform, + user_info=userinfo, + group_info=groupinfo + ) # 创建全局ChatBot实例 chat_bot = ChatBot() \ No newline at end of file diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index b56cdc6e5..17fc2ac6a 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -1,5 +1,6 @@ import asyncio from typing import Optional, Union +from typing import Optional, Union from ...common.database import Database from .message_base import UserInfo @@ -15,6 +16,7 @@ class Impression: class Relationship: user_id: int = None platform: str = None + platform: str = None gender: str = None age: int = None nickname: str = None @@ -33,6 +35,7 @@ class Relationship: class RelationshipManager: def __init__(self): self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键 + self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键 async def update_relationship(self, chat_stream:ChatStream, @@ -63,16 +66,23 @@ class RelationshipManager: # 检查是否在内存中已存在 relationship = self.relationships.get(key) + relationship = self.relationships.get(key) if relationship: # 如果存在,更新现有对象 if isinstance(data, dict): for k, value in data.items(): if hasattr(relationship, k) and value is not None: setattr(relationship, k, value) + for k, value in data.items(): + if hasattr(relationship, k) and value is not None: + setattr(relationship, k, value) else: for k, value in kwargs.items(): if hasattr(relationship, k) and value is not None: setattr(relationship, k, value) + for k, value in kwargs.items(): + if hasattr(relationship, k) and value is not None: + setattr(relationship, k, value) else: # 如果不存在,创建新对象 if user_info is not None: @@ -85,6 +95,16 @@ class RelationshipManager: kwargs['user_id'] = user_id relationship = Relationship(**kwargs) self.relationships[key] = relationship + if user_info is not None: + relationship = Relationship(user_info=user_info, **kwargs) + elif isinstance(data, dict): + data['platform'] = platform + relationship = Relationship(user_id=user_id, data=data) + else: + kwargs['platform'] = platform + kwargs['user_id'] = user_id + relationship = Relationship(**kwargs) + self.relationships[key] = relationship # 保存到数据库 await self.storage_relationship(relationship) @@ -92,6 +112,33 @@ class RelationshipManager: return relationship + async def update_relationship_value(self, + user_id: int = None, + platform: str = None, + user_info: UserInfo = None, + **kwargs) -> Optional[Relationship]: + """更新关系值 + Args: + user_id: 用户ID(可选,如果提供user_info则不需要) + platform: 平台(可选,如果提供user_info则不需要) + user_info: 用户信息对象(可选) + **kwargs: 其他参数 + Returns: + Relationship: 关系对象 + """ + # 确定user_id和platform + if user_info is not None: + user_id = user_info.user_id + platform = user_info.platform or 'qq' + else: + platform = platform or 'qq' + + if user_id is None: + raise ValueError("必须提供user_id或user_info") + + # 使用(user_id, platform)作为键 + key = (user_id, platform) + async def update_relationship_value(self, user_id: int = None, platform: str = None, @@ -121,7 +168,10 @@ class RelationshipManager: # 检查是否在内存中已存在 relationship = self.relationships.get(key) + relationship = self.relationships.get(key) if relationship: + for k, value in kwargs.items(): + if k == 'relationship_value': for k, value in kwargs.items(): if k == 'relationship_value': relationship.relationship_value += value @@ -129,12 +179,41 @@ class RelationshipManager: relationship.saved = True return relationship else: + # 如果不存在且提供了user_info,则创建新的关系 + if user_info is not None: + return await self.update_relationship(user_info=user_info, **kwargs) + print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新") # 如果不存在且提供了user_info,则创建新的关系 if user_info is not None: return await self.update_relationship(user_info=user_info, **kwargs) print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id}({platform}) 不存在,无法更新") return None + def get_relationship(self, + user_id: int = None, + platform: str = None, + user_info: UserInfo = None) -> Optional[Relationship]: + """获取用户关系对象 + Args: + user_id: 用户ID(可选,如果提供user_info则不需要) + platform: 平台(可选,如果提供user_info则不需要) + user_info: 用户信息对象(可选) + Returns: + Relationship: 关系对象 + """ + # 确定user_id和platform + if user_info is not None: + user_id = user_info.user_id + platform = user_info.platform or 'qq' + else: + platform = platform or 'qq' + + if user_id is None: + raise ValueError("必须提供user_id或user_info") + + key = (user_id, platform) + if key in self.relationships: + return self.relationships[key] def get_relationship(self, user_id: int = None, platform: str = None, @@ -169,10 +248,18 @@ class RelationshipManager: if 'platform' not in data: data['platform'] = 'qq' + rela = Relationship(data=data) + """从数据库加载或创建新的关系对象""" + # 确保data中有platform字段,如果没有则默认为'qq' + if 'platform' not in data: + data['platform'] = 'qq' + rela = Relationship(data=data) rela.saved = True key = (rela.user_id, rela.platform) self.relationships[key] = rela + key = (rela.user_id, rela.platform) + self.relationships[key] = rela return rela async def load_all_relationships(self): @@ -190,6 +277,7 @@ class RelationshipManager: # 依次加载每条记录 for data in all_relationships: await self.load_relationship(data) + await self.load_relationship(data) print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") while True: @@ -200,15 +288,19 @@ class RelationshipManager: async def _save_all_relationships(self): """将所有关系数据保存到数据库""" # 保存所有关系数据 + for (userid, platform), relationship in self.relationships.items(): for (userid, platform), relationship in self.relationships.items(): if not relationship.saved: relationship.saved = True await self.storage_relationship(relationship) + async def storage_relationship(self, relationship: Relationship): + """将关系记录存储到数据库中""" async def storage_relationship(self, relationship: Relationship): """将关系记录存储到数据库中""" user_id = relationship.user_id platform = relationship.platform + platform = relationship.platform nickname = relationship.nickname relationship_value = relationship.relationship_value gender = relationship.gender @@ -217,8 +309,10 @@ class RelationshipManager: db = Database.get_instance() db.db.relationships.update_one( + {'user_id': user_id, 'platform': platform}, {'user_id': user_id, 'platform': platform}, {'$set': { + 'platform': platform, 'platform': platform, 'nickname': nickname, 'relationship_value': relationship_value, @@ -229,6 +323,28 @@ class RelationshipManager: upsert=True ) + def get_name(self, + user_id: int = None, + platform: str = None, + user_info: UserInfo = None) -> str: + """获取用户昵称 + Args: + user_id: 用户ID(可选,如果提供user_info则不需要) + platform: 平台(可选,如果提供user_info则不需要) + user_info: 用户信息对象(可选) + Returns: + str: 用户昵称 + """ + # 确定user_id和platform + if user_info is not None: + user_id = user_info.user_id + platform = user_info.platform or 'qq' + else: + platform = platform or 'qq' + + if user_id is None: + raise ValueError("必须提供user_id或user_info") + def get_name(self, user_id: int = None, platform: str = None, @@ -254,6 +370,11 @@ class RelationshipManager: # 确保user_id是整数类型 user_id = int(user_id) key = (user_id, platform) + if key in self.relationships: + return self.relationships[key].nickname + elif user_info is not None: + return user_info.user_nickname or user_info.user_cardname or "某人" + key = (user_id, platform) if key in self.relationships: return self.relationships[key].nickname elif user_info is not None: diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index 170b677dc..85abd5150 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -1,4 +1,5 @@ from typing import Optional, Union +from typing import Optional, Union from ...common.database import Database from .message_base import MessageBase diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index e3d928c10..d430ac74d 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -2,6 +2,9 @@ import asyncio from typing import Dict from loguru import logger +from typing import Dict +from loguru import logger + from .config import global_config from .message_base import UserInfo, GroupInfo from .chat_stream import chat_manager,ChatStream @@ -9,6 +12,7 @@ from .chat_stream import chat_manager,ChatStream class WillingManager: def __init__(self): + self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self._decay_task = None self._started = False @@ -19,6 +23,8 @@ class WillingManager: await asyncio.sleep(5) for chat_id in self.chat_reply_willing: self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) + for chat_id in self.chat_reply_willing: + self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) def get_willing(self,chat_stream:ChatStream) -> float: """获取指定聊天流的回复意愿""" @@ -27,6 +33,9 @@ class WillingManager: return self.chat_reply_willing.get(stream.stream_id, 0) return 0 + def set_willing(self, chat_id: str, willing: float): + """设置指定聊天流的回复意愿""" + self.chat_reply_willing[chat_id] = willing def set_willing(self, chat_id: str, willing: float): """设置指定聊天流的回复意愿""" self.chat_reply_willing[chat_id] = willing @@ -81,6 +90,7 @@ class WillingManager: if reply_probability < 0: reply_probability = 0 + self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0) return reply_probability