diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 8359b9712..37f621bbb 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -13,6 +13,7 @@ from ..memory_system.memory import hippocampus from ..moods.moods import MoodManager # 导入情绪管理器 from .config import global_config from .cq_code import CQCode, cq_code_tool # 导入CQCode模块 +from .cq_code import CQCode, cq_code_tool # 导入CQCode模块 from .emoji_manager import emoji_manager # 导入表情包管理器 from .llm_generator import ResponseGenerator from .message import MessageSending, MessageRecv, MessageThinking, MessageSet @@ -30,6 +31,7 @@ from .willing_manager import willing_manager # 导入意愿管理器 from .message_base import UserInfo, GroupInfo, Seg + class ChatBot: def __init__(self): self.storage = MessageStorage() @@ -95,6 +97,7 @@ class ChatBot: # group_info = await bot.get_group_info(group_id=event.group_id) # sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) + message_cq = MessageRecvCQ( message_cq = MessageRecvCQ( message_id=event.message_id, user_info=user_info, @@ -102,18 +105,29 @@ class ChatBot: group_info=group_info, reply_message=event.reply, platform="qq", + platform="qq", ) message_json = message_cq.to_dict() + message_json = message_cq.to_dict() # 进入maimbot message = MessageRecv(message_json) + groupinfo = message.message_info.group_info + userinfo = message.message_info.user_info + messageinfo = message.message_info + message = MessageRecv(message_json) + groupinfo = message.message_info.group_info userinfo = message.message_info.user_info messageinfo = message.message_info # 消息过滤,涉及到config有待更新 + chat = await chat_manager.get_or_create_stream( + platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo + ) + chat = await chat_manager.get_or_create_stream( platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo ) @@ -150,6 +164,7 @@ class ChatBot: # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) topic = "" + topic = "" interested_rate = 0 interested_rate = ( await hippocampus.memory_activate_value(message.processed_plain_text) / 100 @@ -159,6 +174,8 @@ class ChatBot: await self.storage.store_message(message, chat, topic[0] if topic else None) + await self.storage.store_message(message, chat, topic[0] if topic else None) + is_mentioned = is_mentioned_bot_in_message(message) reply_probability = await willing_manager.change_reply_willing_received( chat_stream=chat, @@ -167,9 +184,11 @@ class ChatBot: config=global_config, is_emoji=message.is_emoji, interested_rate=interested_rate, + interested_rate=interested_rate, ) current_willing = willing_manager.get_willing(chat_stream=chat) + logger.info( f"[{current_time}][{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{chat.user_info.user_nickname}:" f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]" @@ -177,6 +196,7 @@ class ChatBot: response = None + if random() < reply_probability: bot_user_info = UserInfo( user_id=global_config.BOT_QQ, @@ -192,12 +212,16 @@ class ChatBot: reply=message, ) + message_manager.add_message(thinking_message) willing_manager.change_reply_willing_sent(chat) response, raw_content = await self.gpt.generate_response(message) + + response, raw_content = await self.gpt.generate_response(message) + # print(f"response: {response}") if response: # print(f"有response: {response}") @@ -257,7 +281,7 @@ class ChatBot: print(f"添加message_set到message_manager") message_manager.add_message(message_set) - bot_response_time = tinking_time_point + bot_response_time = thinking_time_point if random() < global_config.emoji_chance: emoji_raw = await emoji_manager.get_emoji_for_text(response) @@ -269,7 +293,7 @@ class ChatBot: emoji_cq = image_path_to_base64(emoji_path) if random() < 0.5: - bot_response_time = tinking_time_point - 1 + bot_response_time = thinking_time_point - 1 else: bot_response_time = bot_response_time + 1 diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 185e98edf..0a8a71df3 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -231,7 +231,8 @@ class CQCode: group_info=group_info, ) content_seg = Seg( - type="seglist", data=message_obj.message_segment ) + type="seglist", data=[message_obj.message_segment] + ) else: content_seg = Seg(type="text", data="[空消息]") else: @@ -256,7 +257,7 @@ class CQCode: group_info=group_info, ) content_seg = Seg( - type="seglist", data=message_obj.message_segment + type="seglist", data=[message_obj.message_segment] ) else: content_seg = Seg(type="text", data="[空消息]") @@ -281,11 +282,12 @@ class CQCode: if self.reply_message.sender.user_id: message_obj = MessageRecvCQ( - user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.get("nickname",None)), + user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.nickname), message_id=self.reply_message.message_id, raw_message=str(self.reply_message.message), group_info=GroupInfo(group_id=self.reply_message.group_id), ) + segments = [] if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: @@ -302,7 +304,7 @@ class CQCode: ) ) - segments.append(Seg(type="seglist", data=message_obj.message_segment)) + segments.append(Seg(type="seglist", data=[message_obj.message_segment])) segments.append(Seg(type="text", data="]")) return segments else: diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index e502e357a..32b0abb41 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -18,7 +18,50 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass -class MessageRecv(MessageBase): +class Message(MessageBase): + chat_stream: ChatStream=None + reply: Optional['Message'] = None + detailed_plain_text: str = "" + processed_plain_text: str = "" + + def __init__( + self, + message_id: str, + time: int, + chat_stream: ChatStream, + user_info: UserInfo, + message_segment: Optional[Seg] = None, + reply: Optional['MessageRecv'] = None, + detailed_plain_text: str = "", + processed_plain_text: str = "", + ): + # 构造基础消息信息 + message_info = BaseMessageInfo( + platform=chat_stream.platform, + message_id=message_id, + time=time, + group_info=chat_stream.group_info, + user_info=user_info + ) + + # 调用父类初始化 + super().__init__( + message_info=message_info, + message_segment=message_segment, + raw_message=None + ) + + self.chat_stream = chat_stream + # 文本处理相关属性 + self.processed_plain_text = processed_plain_text + self.detailed_plain_text = detailed_plain_text + + # 回复消息 + self.reply = reply + + +@dataclass +class MessageRecv(Message): """接收消息类,用于处理从MessageCQ序列化的消息""" def __init__(self, message_dict: Dict): @@ -27,24 +70,19 @@ class MessageRecv(MessageBase): Args: message_dict: MessageCQ序列化后的字典 """ - message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - raw_message = message_dict.get("raw_message") - - super().__init__( - message_info=message_info, - message_segment=message_segment, - raw_message=raw_message, - ) - + self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {})) + self.message_segment = Seg.from_dict(message_dict.get('message_segment', {})) + self.raw_message = message_dict.get('raw_message') + # 处理消息内容 self.processed_plain_text = "" # 初始化为空字符串 - self.detailed_plain_text = "" # 初始化为空字符串 - self.is_emoji = False - - def update_chat_stream(self, chat_stream: ChatStream): - self.chat_stream = chat_stream - + self.detailed_plain_text = "" # 初始化为空字符串 + self.is_emoji=False + + + def update_chat_stream(self,chat_stream:ChatStream): + self.chat_stream=chat_stream + async def process(self) -> None: """处理消息内容,生成纯文本和详细文本 @@ -118,48 +156,7 @@ class MessageRecv(MessageBase): else f"{user_info.user_nickname}(ta的id:{user_info.user_id})" ) return f"[{time_str}] {name}: {self.processed_plain_text}\n" - - -@dataclass -class Message(MessageBase): - chat_stream: ChatStream = None - reply: Optional["Message"] = None - detailed_plain_text: str = "" - processed_plain_text: str = "" - - def __init__( - self, - message_id: str, - time: int, - chat_stream: ChatStream, - user_info: UserInfo, - message_segment: Optional[Seg] = None, - reply: Optional["MessageRecv"] = None, - detailed_plain_text: str = "", - processed_plain_text: str = "", - ): - # 构造基础消息信息 - message_info = BaseMessageInfo( - platform=chat_stream.platform, - message_id=message_id, - time=time, - group_info=chat_stream.group_info, - user_info=user_info, - ) - - # 调用父类初始化 - super().__init__( - message_info=message_info, message_segment=message_segment, raw_message=None - ) - - self.chat_stream = chat_stream - # 文本处理相关属性 - self.processed_plain_text = detailed_plain_text - self.detailed_plain_text = processed_plain_text - - # 回复消息 - self.reply = reply - + @dataclass class MessageProcessBase(Message): diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index e9d11f339..3424d662c 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -192,13 +192,11 @@ class LLM_request: logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") # 对全局配置进行更新 - if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get( - 'name') == old_model_name: + if global_config.llm_normal.get('name') == old_model_name: global_config.llm_normal['name'] = self.model_name logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if hasattr(global_config, 'llm_reasoning') and global_config.llm_reasoning.get( - 'name') == old_model_name: + if global_config.llm_reasoning.get('name') == old_model_name: global_config.llm_reasoning['name'] = self.model_name logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") @@ -216,6 +214,7 @@ class LLM_request: # 将流式输出转化为非流式输出 if stream_mode: + flag_delta_content_finished = False accumulated_content = "" async for line_bytes in response.content: line = line_bytes.decode("utf-8").strip() @@ -227,13 +226,25 @@ class LLM_request: break try: chunk = json.loads(data_str) - delta = chunk["choices"][0]["delta"] - delta_content = delta.get("content") - if delta_content is None: - delta_content = "" - accumulated_content += delta_content + if flag_delta_content_finished: + usage = chunk.get("usage", None) # 获取tokn用量 + else: + delta = chunk["choices"][0]["delta"] + delta_content = delta.get("content") + if delta_content is None: + delta_content = "" + accumulated_content += delta_content + # 检测流式输出文本是否结束 + finish_reason = chunk["choices"][0]["finish_reason"] + if finish_reason == "stop": + usage = chunk.get("usage", None) + if usage: + break + # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk + flag_delta_content_finished = True + except Exception: - logger.exception("解析流式输出错") + logger.exception("解析流式输出错误") content = accumulated_content reasoning_content = "" think_match = re.search(r'(.*?)', content, re.DOTALL) @@ -242,7 +253,7 @@ class LLM_request: content = re.sub(r'.*?', '', content, flags=re.DOTALL).strip() # 构造一个伪result以便调用自定义响应处理器或默认处理器 result = { - "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]} + "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}], "usage": usage} return response_handler(result) if response_handler else self._default_response_handler( result, user_id, request_type, endpoint) else: