From 972e6066e66de4a7b6afd1622700fb99a3ea8884 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Wed, 26 Feb 2025 18:12:28 +0800 Subject: [PATCH] v0.1 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 能跑但是没写部署教程,主题和记忆识别也没写完 --- .gitignore | 14 + README.md | 14 + bot.py | 22 + env.example | 22 + flagged/log.csv | 2 + kill_mongodb.bat | 6 + knowledge.bat | 5 + pyproject.toml | 8 + requirements.txt | Bin 0 -> 896 bytes runniuniu - db.bat | 1 + runniuniu.bat | 5 + src/common/__init__.py | 1 + src/common/database.py | 21 + src/plugins/chat/.stream.py | 201 ++++++++ src/plugins/chat/__init__.py | 82 +++ src/plugins/chat/bot.py | 224 ++++++++ src/plugins/chat/bot_config.toml | 73 +++ src/plugins/chat/config.py | 109 ++++ src/plugins/chat/cq_code.py | 422 +++++++++++++++ src/plugins/chat/emoji_manager.py | 414 +++++++++++++++ src/plugins/chat/gpt_response.py | 544 ++++++++++++++++++++ src/plugins/chat/group_info_manager.py | 107 ++++ src/plugins/chat/image_utils.py | 162 ++++++ src/plugins/chat/info_gui.py | 76 +++ src/plugins/chat/llm_generator.py | 108 ++++ src/plugins/chat/message.py | 318 ++++++++++++ src/plugins/chat/message_send_control.py | 322 ++++++++++++ src/plugins/chat/message_stream.py | 264 ++++++++++ src/plugins/chat/message_visualizer.py | 138 +++++ src/plugins/chat/prompt_builder.py | 193 +++++++ src/plugins/chat/relationship_manager.py | 200 +++++++ src/plugins/chat/storage.py | 48 ++ src/plugins/chat/topic_identifier.py | 96 ++++ src/plugins/chat/utils.py | 115 +++++ src/plugins/chat/willing_manager.py | 77 +++ src/plugins/schedule/schedule_generator.py | 156 ++++++ src/plugins/schedule/schedule_llm_module.py | 55 ++ 37 files changed, 4625 insertions(+) create mode 100644 bot.py create mode 100644 env.example create mode 100644 flagged/log.csv create mode 100644 kill_mongodb.bat create mode 100644 knowledge.bat create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 
100644 runniuniu - db.bat create mode 100644 runniuniu.bat create mode 100644 src/common/__init__.py create mode 100644 src/common/database.py create mode 100644 src/plugins/chat/.stream.py create mode 100644 src/plugins/chat/__init__.py create mode 100644 src/plugins/chat/bot.py create mode 100644 src/plugins/chat/bot_config.toml create mode 100644 src/plugins/chat/config.py create mode 100644 src/plugins/chat/cq_code.py create mode 100644 src/plugins/chat/emoji_manager.py create mode 100644 src/plugins/chat/gpt_response.py create mode 100644 src/plugins/chat/group_info_manager.py create mode 100644 src/plugins/chat/image_utils.py create mode 100644 src/plugins/chat/info_gui.py create mode 100644 src/plugins/chat/llm_generator.py create mode 100644 src/plugins/chat/message.py create mode 100644 src/plugins/chat/message_send_control.py create mode 100644 src/plugins/chat/message_stream.py create mode 100644 src/plugins/chat/message_visualizer.py create mode 100644 src/plugins/chat/prompt_builder.py create mode 100644 src/plugins/chat/relationship_manager.py create mode 100644 src/plugins/chat/storage.py create mode 100644 src/plugins/chat/topic_identifier.py create mode 100644 src/plugins/chat/utils.py create mode 100644 src/plugins/chat/willing_manager.py create mode 100644 src/plugins/schedule/schedule_generator.py create mode 100644 src/plugins/schedule/schedule_llm_module.py diff --git a/.gitignore b/.gitignore index 15201acc1..71a620d7e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,17 @@ +data/ +mongodb/ +NapCat.Framework.Windows.Once/ +log/ +src/plugins/memory +/test +message_queue_content.txt +message_queue_content.bat +message_queue_window.bat +message_queue_window.txt +reasoning_content.txt +reasoning_content.bat +reasoning_window.bat + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 46053e515..d775ad7d1 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,18 @@ # MaiMBot 麦麦 qq机器人 +<<<<<<< Updated 
upstream 还在整理中 +======= +*还没整理完* + +基于napcat,nonebot和mongodb + +要把env.example改成.env,并填上你的apikey(硅基流动) + +未完成整理,可以跑,但是还没写部署方式 +功能和文件结构随时可能发生变化 + +主要代码在/src/plugins/chat下 + +>>>>>>> Stashed changes diff --git a/bot.py b/bot.py new file mode 100644 index 000000000..f9544f404 --- /dev/null +++ b/bot.py @@ -0,0 +1,22 @@ +import nonebot +from nonebot.adapters.onebot.v11 import Adapter + +# 初始化 NoneBot +nonebot.init( + # napcat 默认使用 8080 端口 + websocket_port=8080, + # 设置日志级别 + log_level="INFO", + # 设置超级用户 + superusers={"你的QQ号"} +) + +# 注册适配器 +driver = nonebot.get_driver() +driver.register_adapter(Adapter) + +# 加载插件 +nonebot.load_plugins("src/plugins") + +if __name__ == "__main__": + nonebot.run() \ No newline at end of file diff --git a/env.example b/env.example new file mode 100644 index 000000000..7680540df --- /dev/null +++ b/env.example @@ -0,0 +1,22 @@ +ENVIRONMENT=dev +HOST=127.0.0.1 +PORT=8080 + +COMMAND_START=["/"] + +# 插件配置 +PLUGINS=["src2.plugins.chat"] + +# 默认配置 +MONGODB_HOST=127.0.0.1 +MONGODB_PORT=27017 +DATABASE_NAME=MegBot + +#key and url +CHAT_ANY_WHERE_KEY= +SILICONFLOW_KEY= +CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 +SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ + + + diff --git a/flagged/log.csv b/flagged/log.csv new file mode 100644 index 000000000..daeef4a9b --- /dev/null +++ b/flagged/log.csv @@ -0,0 +1,2 @@ +输入消息,推理内容,flag,username,timestamp +显示内容,,,,2025-02-18 16:50:53.643238 diff --git a/kill_mongodb.bat b/kill_mongodb.bat new file mode 100644 index 000000000..366f05d32 --- /dev/null +++ b/kill_mongodb.bat @@ -0,0 +1,6 @@ +@echo off +echo 正在查找并结束所有 MongoDB 进程... 
+taskkill /F /IM mongod.exe +taskkill /F /IM mongo.exe +echo MongoDB 进程已结束 +pause \ No newline at end of file diff --git a/knowledge.bat b/knowledge.bat new file mode 100644 index 000000000..e6ad209e4 --- /dev/null +++ b/knowledge.bat @@ -0,0 +1,5 @@ +call conda activate niuniu +cd "C:\GitHub\MegMeg-bot" + +REM 执行nb run命令 +nb run \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..4f06cd5ae --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[project] +name = "Megbot" +version = "0.1.0" +description = "New Bot Project" + +[tool.nonebot] +plugins = ["src.plugins.chat"] +plugin_dirs = ["src/plugins"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b74e55a31d6e2202e45a958268d1633effc6a20d GIT binary patch literal 896 zcmZuvJx>Br5S+qTnAlpISc80tSeS^NiAD*D#zM^jCn^UVZ~`_OOC?s;{vCUJTN@K& zsj*OJ-s3=oyzIfJ~*bw^2>?Qt}HgN zh)td?e70LTb=Ju*Hirn(CcdADCGZJ6OT*n0Xe+Ev@7J0KA+g~!Frm- z4)POkj`h**TrARECTf!*KZGUSqx0hEDKBM1UDkDVWLPacdlL#objV!kD-3U<)pcyL O70!_{c`=F6W`6 "Database": + if cls._instance is None: + cls._instance = cls(host, port, db_name) + return cls._instance + + @classmethod + def get_instance(cls) -> "Database": + if cls._instance is None: + raise RuntimeError("Database not initialized") + return cls._instance \ No newline at end of file diff --git a/src/plugins/chat/.stream.py b/src/plugins/chat/.stream.py new file mode 100644 index 000000000..44fed5b70 --- /dev/null +++ b/src/plugins/chat/.stream.py @@ -0,0 +1,201 @@ +from typing import Dict, List, Optional +from dataclasses import dataclass +import time +import threading +import asyncio +from .message import Message +from .storage import MessageStorage +from .topic_identifier import TopicIdentifier +from ...common.database import Database +import random + +@dataclass +class Topic: + id: str + name: str + messages: List[Message] + created_time: float + 
last_active_time: float + message_count: int + is_active: bool = True + +class MessageStream: + def __init__(self): + self.storage = MessageStorage() + self.active_topics: Dict[int, List[Topic]] = {} # group_id -> topics + self.topic_identifier = TopicIdentifier() + self.db = Database.get_instance() + self.topic_lock = threading.Lock() + + async def start(self): + """异步初始化""" + asyncio.create_task(self._monitor_topics()) + + async def _monitor_topics(self): + """定时监控主题状态""" + while True: + await asyncio.sleep(30) + self._print_active_topics() + self._check_inactive_topics() + self._remove_small_topic() + + def _print_active_topics(self): + """打印当前活跃主题""" + print("\n" + "="*50) + print("\033[1;36m【当前活跃主题】\033[0m") # 青色 + for group_id, topics in self.active_topics.items(): + active_topics = [t for t in topics if t.is_active] + if active_topics: + print(f"\n\033[1;33m群组 {group_id}:\033[0m") # 黄色 + for topic in active_topics: + print(f"\033[1;32m- {topic.name}\033[0m (消息数: {topic.message_count})") # 绿色 + + def _check_inactive_topics(self): + """检查并处理不活跃主题""" + current_time = time.time() + INACTIVE_TIME = 600 # 60秒内没有新增内容 + # MAX_MESSAGES_WITHOUT_TOPIC = 5 # 最新5条消息都不是这个主题就归档 + + with self.topic_lock: + for group_id, topics in self.active_topics.items(): + + for topic in topics: + if not topic.is_active: + continue + + # 检查是否超过不活跃时间 + time_inactive = current_time - topic.last_active_time + if time_inactive > INACTIVE_TIME: + # print(f"\033[1;33m[主题超时]\033[0m {topic.name} 已有 {int(time_inactive)} 秒未更新") + self._archive_topic(group_id, topic) + topic.is_active = False + continue + + + def _archive_topic(self, group_id: int, topic: Topic): + """将主题存档到数据库""" + # 查找是否有同名主题 + existing_topic = self.db.db.archived_topics.find_one({ + "name": topic.name + }) + + if existing_topic: + # 合并消息列表并去重 + existing_messages = existing_topic.get("messages", []) + new_messages = [ + { + "user_id": msg.user_id, + "plain_text": msg.plain_text, + "time": msg.time + } for msg in topic.messages + 
] + + # 使用集合去重 + seen_texts = set() + unique_messages = [] + + # 先处理现有消息 + for msg in existing_messages: + if msg["plain_text"] not in seen_texts: + seen_texts.add(msg["plain_text"]) + unique_messages.append(msg) + + # 再处理新消息 + for msg in new_messages: + if msg["plain_text"] not in seen_texts: + seen_texts.add(msg["plain_text"]) + unique_messages.append(msg) + + # 更新主题信息 + self.db.db.archived_topics.update_one( + {"_id": existing_topic["_id"]}, + { + "$set": { + "messages": unique_messages, + "message_count": len(unique_messages), + "last_active_time": max(existing_topic["last_active_time"], topic.last_active_time), + "last_merged_time": time.time() + } + } + ) + print(f"\033[1;33m[主题合并]\033[0m 主题 {topic.name} 已合并,总消息数: {len(unique_messages)}") + + else: + # 存储新主题 + self.db.db.archived_topics.insert_one({ + "topic_id": topic.id, + "name": topic.name, + "messages": [ + { + "user_id": msg.user_id, + "plain_text": msg.plain_text, + "time": msg.time + } for msg in topic.messages + ], + "created_time": topic.created_time, + "last_active_time": topic.last_active_time, + "message_count": topic.message_count + }) + print(f"\033[1;32m[主题存档]\033[0m {topic.name} (群组: {group_id})") + + async def process_message(self, message: Message,topic:List[str]): + """处理新消息,返回识别出的主题列表""" + # 存储消息(包含主题) + await self.storage.store_message(message, topic) + self._update_topics(message.group_id, topic, message) + + def _update_topics(self, group_id: int, topic_names: List[str], message: Message) -> None: + """更新群组主题""" + current_time = time.time() + + # 确保群组存在 + if group_id not in self.active_topics: + self.active_topics[group_id] = [] + + # 查找现有主题 + for topic_name in topic_names: + for topic in self.active_topics[group_id]: + if topic.name == topic_name: + topic.messages.append(message) + topic.last_active_time = current_time + topic.message_count += 1 + print(f"\033[1;35m[更新主题]\033[0m {topic_name}") # 绿色 + break + else: + # 创建新主题 + new_topic = Topic( + id=f"{group_id}_{int(current_time)}", 
+ name=topic_name, + messages=[message], + created_time=current_time, + last_active_time=current_time, + message_count=1 + ) + self.active_topics[group_id].append(new_topic) + + self._check_inactive_topics() + + def _remove_small_topic(self): + """随机移除一个12小时内没有新增内容的小主题""" + try: + current_time = time.time() + inactive_time = 12 * 3600 # 24小时 + + # 获取所有符合条件的主题 + topics = list(self.db.db.archived_topics.find({ + "message_count": {"$lt": 3}, # 消息数小于2 + "last_active_time": {"$lt": current_time - inactive_time} + })) + + if not topics: + return + + # 随机选择一个主题删除 + topic_to_remove = random.choice(topics) + inactive_hours = (current_time - topic_to_remove.get("last_active_time", 0)) / 3600 + + self.db.db.archived_topics.delete_one({"_id": topic_to_remove["_id"]}) + print(f"\033[1;31m[主题清理]\033[0m 已移除小主题: {topic_to_remove['name']} " + f"不活跃时间: {int(inactive_hours)}小时)") + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 移除小主题失败: {str(e)}") diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py new file mode 100644 index 000000000..de1fa4e19 --- /dev/null +++ b/src/plugins/chat/__init__.py @@ -0,0 +1,82 @@ +from nonebot import on_message, on_command, require, get_driver +from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment +from nonebot.typing import T_State +from ...common.database import Database +from .config import global_config +import os +import asyncio +import random +from .relationship_manager import relationship_manager +from ..schedule.schedule_generator import bot_schedule +from .willing_manager import willing_manager + + + +# 获取驱动器 +driver = get_driver() + +Database.initialize( + global_config.MONGODB_HOST, + global_config.MONGODB_PORT, + global_config.DATABASE_NAME +) + +print("\033[1;32m[初始化配置和数据库完成]\033[0m") + + +# 导入其他模块 +from .bot import ChatBot +from .emoji_manager import emoji_manager +from .message_send_control import message_sender +from .relationship_manager import relationship_manager + +# 
初始化表情管理器 +emoji_manager.initialize() + +print("\033[1;32m正在唤醒麦麦......\033[0m") +# 创建机器人实例 +chat_bot = ChatBot(global_config) + +# 注册消息处理器 +group_msg = on_message() + +# 创建定时任务 +scheduler = require("nonebot_plugin_apscheduler").scheduler + +# 启动后台任务 +@driver.on_startup +async def start_background_tasks(): + """启动后台任务""" + # 只启动表情包管理任务 + asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) + + bot_schedule.print_schedule() + +@driver.on_bot_connect +async def _(bot: Bot): + """Bot连接成功时的处理""" + print("\033[1;38;5;208m-----------麦麦成功连接!-----------\033[0m") + message_sender.set_bot(bot) + asyncio.create_task(message_sender.start_processor(bot)) + await willing_manager.ensure_started() + print("\033[1;38;5;208m-----------麦麦消息发送器已启动!-----------\033[0m") + + asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) + print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") + # 启动消息发送控制任务 + +@driver.on_startup +async def init_relationships(): + """在 NoneBot2 启动时初始化关系管理器""" + print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...") + asyncio.create_task(relationship_manager._start_relationship_manager()) + +@group_msg.handle() +async def _(bot: Bot, event: GroupMessageEvent, state: T_State): + await chat_bot.handle_message(event, bot) + +@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships") +async def monitor_relationships(): + """每15秒打印一次关系数据""" + relationship_manager.print_all_relationships() + diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py new file mode 100644 index 000000000..7ef675f67 --- /dev/null +++ b/src/plugins/chat/bot.py @@ -0,0 +1,224 @@ +from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot +from .message import Message,MessageSet +from .config import BotConfig, global_config +from .storage import MessageStorage +from .gpt_response import GPTResponseGenerator +from .message_stream import 
MessageStream, MessageStreamContainer +from .topic_identifier import topic_identifier +from random import random +from nonebot.log import logger +from .group_info_manager import GroupInfoManager # 导入群信息管理器 +from .emoji_manager import emoji_manager # 导入表情包管理器 +import time +import os +from .cq_code import CQCode # 导入CQCode模块 +from .message_send_control import message_sender # 导入消息发送控制器 +from .message import Message_Thinking # 导入 Message_Thinking 类 +from .relationship_manager import relationship_manager +from .prompt_builder import prompt_builder +from .willing_manager import willing_manager # 导入意愿管理器 + + +class ChatBot: + def __init__(self, config: BotConfig): + self.config = config + self.storage = MessageStorage() + self.gpt = GPTResponseGenerator(config) + self.group_info_manager = GroupInfoManager() # 初始化群信息管理器 + self.bot = None # bot 实例引用 + self._started = False + + self.emoji_chance = 0.2 # 发送表情包的基础概率 + self.message_streams = MessageStreamContainer() + self.message_sender = message_sender + + async def _ensure_started(self): + """确保所有任务已启动""" + if not self._started: + # 只保留必要的任务 + self._started = True + + def is_mentioned_bot(self, message: Message) -> bool: + """检查消息是否提到了机器人""" + keywords = ['麦麦'] + for keyword in keywords: + if keyword in message.processed_plain_text: + return True + return False + + + async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None: + """处理收到的群消息""" + + if event.group_id not in self.config.talk_allowed_groups: + return + self.bot = bot # 更新 bot 实例 + + # 打印原始消息内容 + ''' + print(f"\n\033[1;33m[消息详情]\033[0m") + # print(f"- 原始消息: {str(event.raw_message)}") + print(f"- post_type: {event.post_type}") + print(f"- sub_type: {event.sub_type}") + print(f"- user_id: {event.user_id}") + print(f"- message_type: {event.message_type}") + # print(f"- message_id: {event.message_id}") + # print(f"- message: {event.message}") + print(f"- original_message: {event.original_message}") + print(f"- raw_message: {event.raw_message}") + # 
print(f"- font: {event.font}") + print(f"- sender: {event.sender}") + # print(f"- to_me: {event.to_me}") + + if event.reply: + print(f"\n\033[1;33m[回复消息详情]\033[0m") + # print(f"- message_id: {event.reply.message_id}") + print(f"- message_type: {event.reply.message_type}") + print(f"- sender: {event.reply.sender}") + # print(f"- time: {event.reply.time}") + print(f"- message: {event.reply.message}") + print(f"- raw_message: {event.reply.raw_message}") + # print(f"- original_message: {event.reply.original_message}") + ''' + + # 获取群组信息,发送消息的用户信息,并对数据库内容做一次更新 + + group_info = await bot.get_group_info(group_id=event.group_id) + await self.group_info_manager.update_group_info( + group_id=event.group_id, + group_name=group_info['group_name'], + member_count=group_info['member_count'] + ) + + + sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) + + # print(f"\033[1;32m[关系管理]\033[0m 更新关系: {sender_info}") + + await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info) + await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5) + print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}") + + + + message = Message( + group_id=event.group_id, + user_id=event.user_id, + message_id=event.message_id, + raw_message=str(event.original_message), + plain_text=event.get_plaintext(), + reply_message=event.reply, + ) + + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) + + topic = topic_identifier.identify_topic_jieba(message.processed_plain_text) + print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}") + + await self.storage.store_message(message, topic[0] if topic else None) + + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) + + print(f"\033[1;34m[调试]\033[0m 当前消息是否是表情包: {message.is_emoji}") + + is_mentioned = self.is_mentioned_bot(message) + 
reply_probability = willing_manager.change_reply_willing_received( + event.group_id, + topic[0] if topic else None, + is_mentioned, + self.config, + event.user_id, + message.is_emoji + ) + current_willing = willing_manager.get_willing(event.group_id) + + + print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability:.1f}]\033[0m") + response = "" + if random() < reply_probability: + + tinking_time_point = round(time.time(), 2) + think_id = 'mt' + str(tinking_time_point) + + thinking_message = Message_Thinking(message=message,message_id=think_id) + + message_sender.send_temp_container.add_message(thinking_message) + + willing_manager.change_reply_willing_sent(thinking_message.group_id) + # 生成回复 + response, emotion = await self.gpt.generate_response(message) + + # 如果生成了回复,发送并记录 + if response: + message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id) + if isinstance(response, list): + # 将多条消息合并成一条 + for msg in response: + # print(f"\033[1;34m[调试]\033[0m 载入消息消息: {msg}") + # bot_response_time = round(time.time(), 2) + timepoint = tinking_time_point-0.3 + bot_message = Message( + group_id=event.group_id, + user_id=self.config.BOT_QQ, + message_id=think_id, + message_based_id=event.message_id, + raw_message=msg, + plain_text=msg, + processed_plain_text=msg, + user_nickname="麦麦", + group_name=message.group_name, + time=timepoint + ) + # print(f"\033[1;34m[调试]\033[0m 添加消息到消息组: {bot_message}") + message_set.add_message(bot_message) + # print(f"\033[1;34m[调试]\033[0m 输入消息组: {message_set}") + message_sender.send_temp_container.update_thinking_message(message_set) + else: + # bot_response_time = round(time.time(), 2) + bot_message = Message( + group_id=event.group_id, + user_id=self.config.BOT_QQ, + message_id=think_id, + message_based_id=event.message_id, + raw_message=response, + plain_text=response, + processed_plain_text=response, + 
user_nickname="麦麦", + group_name=message.group_name, + time=tinking_time_point + ) + # print(f"\033[1;34m[调试]\033[0m 更新单条消息: {bot_message}") + message_sender.send_temp_container.update_thinking_message(bot_message) + + + bot_response_time = tinking_time_point + if random() < self.config.emoji_chance: + emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) + if emoji_path: + emoji_cq = CQCode.create_emoji_cq(emoji_path) + + if random() < 0.5: + bot_response_time = tinking_time_point - 1 + # else: + # bot_response_time = bot_response_time + 1 + + bot_message = Message( + group_id=event.group_id, + user_id=self.config.BOT_QQ, + message_id=0, + raw_message=emoji_cq, + plain_text=emoji_cq, + processed_plain_text=emoji_cq, + user_nickname="麦麦", + group_name=message.group_name, + time=bot_response_time, + is_emoji=True + ) + message_sender.send_temp_container.add_message(bot_message) + + + + + + # 如果收到新消息,提高回复意愿 + willing_manager.change_reply_willing_after_sent(event.group_id) \ No newline at end of file diff --git a/src/plugins/chat/bot_config.toml b/src/plugins/chat/bot_config.toml new file mode 100644 index 000000000..539784b92 --- /dev/null +++ b/src/plugins/chat/bot_config.toml @@ -0,0 +1,73 @@ +[database] +host = "127.0.0.1" +port = 27017 +name = "MegBot" + +[bot] +qq = 2814567326 + +[message] +min_text_length = 2 +max_context_size = 15 +emoji_chance = 0.2 + +[emoji] +check_interval = 120 +register_interval = 10 + +[response] +model_r1_probability = 0.2 + + +[groups] +read_allowed = [ + 1030993430, #bot_test_group_1 + # 1015816696, #m43white + 739044565, #my_group + 192194125, #ms + 591693379, #bot_test_group_2 + 179648561, #nkyy + 764408046, #daily_news + 435591861, #m43black + 851345375, #hjy群 + 708847644, #rotate_cmy + 534940728, #bh_llh_HYY + # 549292720, #mrfz + # 231561425, #粉丝群 + 975992476, + 1140700103, + 752426484,#nd1 + 115843978,#nd2 + # 168718420 #bh +] + +talk_allowed = [ + 1030993430, #bot_test_group_1 + # 1015816696, #m43white + 739044565, 
#my_group + 192194125, #ms + 591693379, #bot_test_group_2 + 179648561, #nkyy + 764408046, #daily_news + #435591861, #m43black + 851345375, #hjy群 + 708847644, #rotate_cmy + 534940728, #bh_llh_HYY + # 231561425, #粉丝群 + 975992476, + 1140700103, + # 168718420#bh + # 752426484,#nd1 + # 115843978,#nd2 +] + +talk_frequency_down = [ + 549292720, #mrfz + 435591861, #m43black + # 231561425, + 975992476, + 1140700103, + 534940728 + # 752426484,#nd1 + # 115843978,#nd2 +] diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py new file mode 100644 index 000000000..06ac85b5a --- /dev/null +++ b/src/plugins/chat/config.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass +from typing import Dict, Any, Optional +import os +from nonebot.log import logger, default_format +import logging +import configparser # 添加这行导入 +import tomli # 添加这行导入 + +# 禁用默认的日志输出 +# logger.remove() + +# # 只禁用 INFO 级别的日志输出到控制台 +# logging.getLogger('nonebot').handlers.clear() +# console_handler = logging.StreamHandler() +# console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别 +# logging.getLogger('nonebot').addHandler(console_handler) +# logging.getLogger('nonebot').setLevel(logging.WARNING) + +@dataclass +class BotConfig: + """机器人配置类""" + + # 基础配置 + MONGODB_HOST: str = "127.0.0.1" + MONGODB_PORT: int = 27017 + DATABASE_NAME: str = "MegBot" + + BOT_QQ: Optional[int] = None + + # 消息处理相关配置 + MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度 + MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数 + emoji_chance: float = 0.2 # 发送表情包的基础概率 + + read_allowed_groups = set() + talk_allowed_groups = set() + talk_frequency_down_groups = set() + + EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) + EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) + + MODEL_R1_PROBABILITY: float = 0.3 # R1模型概率 + + @classmethod + def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig": + """从TOML配置文件加载配置""" + config = cls() + if os.path.exists(config_path): + with open(config_path, "rb") as f: + toml_dict = tomli.load(f) + + # 数据库配置 + 
if "database" in toml_dict: + db_config = toml_dict["database"] + config.MONGODB_HOST = db_config.get("host", config.MONGODB_HOST) + config.MONGODB_PORT = db_config.get("port", config.MONGODB_PORT) + config.DATABASE_NAME = db_config.get("name", config.DATABASE_NAME) + + if "emoji" in toml_dict: + emoji_config = toml_dict["emoji"] + config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) + config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) + + # 机器人基础配置 + if "bot" in toml_dict: + bot_config = toml_dict["bot"] + bot_qq = bot_config.get("qq") + config.BOT_QQ = int(bot_qq) + + + if "response" in toml_dict: + response_config = toml_dict["response"] + config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY) + + # 消息配置 + if "message" in toml_dict: + msg_config = toml_dict["message"] + config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH) + config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE) + config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance) + + # 群组配置 + if "groups" in toml_dict: + groups_config = toml_dict["groups"] + config.read_allowed_groups = set(groups_config.get("read_allowed", [])) + config.talk_allowed_groups = set(groups_config.get("talk_allowed", [])) + config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", [])) + + print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m") + + return config + +global_config = BotConfig.load_config("./src/plugins/chat/bot_config.toml") + +from dotenv import load_dotenv +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) +load_dotenv(os.path.join(root_dir, '.env')) + +@dataclass +class LLMConfig: + """机器人配置类""" + # 基础配置 + SILICONFLOW_API_KEY: str = None + SILICONFLOW_BASE_URL: str = None + +llm_config = LLMConfig() 
+llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY') +llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL') diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py new file mode 100644 index 000000000..b9a92dec9 --- /dev/null +++ b/src/plugins/chat/cq_code.py @@ -0,0 +1,422 @@ +from dataclasses import dataclass +from typing import Dict, Optional +import html +import requests +import base64 +from PIL import Image +import io +from .image_utils import storage_compress_image, storage_emoji +import os +from random import random +from nonebot.adapters.onebot.v11 import Bot +from .config import global_config, llm_config +import time +import asyncio +@dataclass +class CQCode: + """ + CQ码数据类,用于存储和处理CQ码 + + 属性: + type: CQ码类型(如'image', 'at', 'face'等) + params: CQ码的参数字典 + raw_code: 原始CQ码字符串 + translated_plain_text: 经过处理(如AI翻译)后的文本表示 + """ + type: str + params: Dict[str, str] + raw_code: str + group_id: int + user_id: int + group_name: str = "" + user_nickname: str = "" + translated_plain_text: Optional[str] = None + reply_message: Dict = None # 存储回复消息 + + @classmethod + def from_cq_code(cls, cq_code: str, reply: Dict = None) -> 'CQCode': + """ + 从CQ码字符串创建CQCode对象 + 例如:[CQ:image,file=1.jpg,url=http://example.com/1.jpg] + """ + # 移除前后的[] + content = cq_code[1:-1] + # 分离类型和参数部分 + parts = content.split(',') + if not parts: + return cls('text', {'text': cq_code}, cq_code, group_id=0, user_id=0) + + # 获取CQ类型 + cq_type = parts[0][3:] # 去掉'CQ:' + + # 解析参数 + params = {} + for part in parts[1:]: + if '=' in part: + key, value = part.split('=', 1) + # 处理转义字符 + value = cls.unescape(value) + params[key] = value + + # 创建实例 + instance = cls(cq_type, params, cq_code, group_id=0, user_id=0, reply_message=reply) + # 根据类型进行相应的翻译处理 + instance.translate() + return instance + + def translate(self): + """根据CQ码类型进行相应的翻译处理""" + if self.type == 'text': + self.translated_plain_text = self.params.get('text', '') + elif self.type == 'image': + 
self.translated_plain_text = self.translate_image() + elif self.type == 'at': + from .message import Message + message_obj = Message( + user_id=str(self.params.get('qq', '')) + ) + self.translated_plain_text = f"@{message_obj.user_nickname}" + elif self.type == 'reply': + self.translated_plain_text = self.translate_reply() + elif self.type == 'face': + face_id = self.params.get('id', '') + # self.translated_plain_text = f"[表情{face_id}]" + self.translated_plain_text = f"[表情]" + elif self.type == 'forward': + self.translated_plain_text = self.translate_forward() + else: + self.translated_plain_text = f"[{self.type}]" + + def translate_image(self) -> str: + """处理图片类型的CQ码,区分普通图片和表情包""" + if 'url' not in self.params: + return '[图片]' + + # 获取子类型,默认为普通图片(0) + sub_type = int(self.params.get('sub_type', '0')) + is_emoji = (sub_type == 1) + + # 添加请求头 + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36', + 'Accept': 'image/webp,image/apng,image/*,*/*;q=0.8', + 'Accept-Encoding': 'gzip, deflate, br', + 'Connection': 'keep-alive' + } + + # 处理URL编码问题 + url = html.unescape(self.params['url']) + + if not url.startswith(('http://', 'https://')): + raise ValueError(f"无效的URL格式: {url}") + + # 下载图片 + response = requests.get(url, headers=headers, timeout=10, verify=False) + + if response.status_code == 200: + # 检查响应内容类型 + content_type = response.headers.get('content-type', '') + if not content_type.startswith('image/'): + raise ValueError(f"响应不是图片类型: {content_type}") + + content = response.content + + image_base64 = base64.b64encode(content).decode('utf-8') + + # 根据子类型选择不同的处理方式 + if sub_type == 1: # 表情包 + return self.get_emoji_description(image_base64) + elif sub_type == 0: # 普通图片 + if self.get_image_description_is_setu(image_base64) == "是": + print(f"\033[1;34m[调试]\033[0m 哇!涩情图片") + # 使用相对路径创建目录 + # data_dir = 
os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data", "setu") + # os.makedirs(data_dir, exist_ok=True) + # # 生成随机文件名 + # file_name = f"{int(time.time())}_{int(random() * 10000)}.jpg" + # file_path = os.path.join(data_dir, file_name) + # # 将base64解码并保存图片 + # image_data = base64.b64decode(image_base64) + # with open(file_path, "wb") as f: + # f.write(image_data) + # print(f"\033[1;34m[调试]\033[0m 涩图已保存至: {file_path}") + + return f"[一张涩情图片]" + return self.get_image_description(image_base64) + else: # 其他类型都按普通图片处理 + return '[图片]' + else: + raise ValueError(f"下载图片失败: HTTP状态码 {response.status_code}") + + + def get_emoji_description(self, image_base64: str) -> str: + """调用AI接口获取表情包描述""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + } + + payload = { + "model": "deepseek-ai/deepseek-vl2", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + } + } + ] + } + ], + "max_tokens": 50, + "temperature": 0.4 + } + + response = requests.post( + f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + headers=headers, + json=payload, + timeout=30 + ) + + if response.status_code == 200: + result_json = response.json() + if "choices" in result_json and len(result_json["choices"]) > 0: + description = result_json["choices"][0]["message"]["content"] + return f"[表情包:{description}]" + + raise ValueError(f"AI接口调用失败: {response.text}") + + def get_image_description(self, image_base64: str) -> str: + """调用AI接口获取普通图片描述""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + } + + payload = { + "model": "deepseek-ai/deepseek-vl2", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": 
"请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + } + } + ] + } + ], + "max_tokens": 300, + "temperature": 0.6 + } + + response = requests.post( + f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + headers=headers, + json=payload, + timeout=30 + ) + + if response.status_code == 200: + result_json = response.json() + if "choices" in result_json and len(result_json["choices"]) > 0: + description = result_json["choices"][0]["message"]["content"] + return f"[图片:{description}]" + + raise ValueError(f"AI接口调用失败: {response.text}") + + + def get_image_description_is_setu(self, image_base64: str) -> str: + """调用AI接口获取普通图片描述""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + } + + payload = { + "model": "deepseek-ai/deepseek-vl2", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "请回答我这张图片是否涉及涩情、情色、裸露或性暗示,请严格判断,有任何涩情迹象就回答是,请用是或否回答" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + } + } + ] + } + ], + "max_tokens": 300, + "temperature": 0.6 + } + + response = requests.post( + f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + headers=headers, + json=payload, + timeout=30 + ) + + if response.status_code == 200: + result_json = response.json() + if "choices" in result_json and len(result_json["choices"]) > 0: + description = result_json["choices"][0]["message"]["content"] + # 如果描述中包含"否",返回否,其他情况返回是 + return "否" if "否" in description else "是" + + raise ValueError(f"AI接口调用失败: {response.text}") + + def translate_forward(self) -> str: + """处理转发消息""" + try: + if 'content' not in self.params: + return '[转发消息]' + + # 解析content内容(需要先反转义) + content = self.unescape(self.params['content']) + # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") + # 将字符串形式的列表转换为Python对象 + import ast + try: + messages = 
ast.literal_eval(content) + except ValueError as e: + print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") + return '[转发消息]' + + # 处理每条消息 + formatted_messages = [] + for msg in messages: + sender = msg.get('sender', {}) + nickname = sender.get('card') or sender.get('nickname', '未知用户') + + # 获取消息内容并使用Message类处理 + raw_message = msg.get('raw_message', '') + message_array = msg.get('message', []) + + if message_array and isinstance(message_array, list): + # 检查是否包含嵌套的转发消息 + for message_part in message_array: + if message_part.get('type') == 'forward': + content = '[转发消息]' + break + else: + # 处理普通消息 + if raw_message: + from .message import Message + message_obj = Message( + user_id=msg.get('user_id', 0), + message_id=msg.get('message_id', 0), + raw_message=raw_message, + plain_text=raw_message, + group_id=msg.get('group_id', 0) + ) + content = message_obj.processed_plain_text + else: + content = '[空消息]' + else: + # 处理普通消息 + if raw_message: + from .message import Message + message_obj = Message( + user_id=msg.get('user_id', 0), + message_id=msg.get('message_id', 0), + raw_message=raw_message, + plain_text=raw_message, + group_id=msg.get('group_id', 0) + ) + content = message_obj.processed_plain_text + else: + content = '[空消息]' + + formatted_msg = f"{nickname}: {content}" + formatted_messages.append(formatted_msg) + + # 合并所有消息 + combined_messages = '\n'.join(formatted_messages) + print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") + return f"[转发消息:\n{combined_messages}]" + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") + return '[转发消息]' + + def translate_reply(self) -> str: + """处理回复类型的CQ码""" + + # 创建Message对象 + from .message import Message + if self.reply_message == None: + return '[回复某人消息]' + + if self.reply_message.sender.user_id: + message_obj = Message( + user_id=self.reply_message.sender.user_id, + message_id=self.reply_message.message_id, + raw_message=str(self.reply_message.message), + group_id=self.group_id + ) + if 
message_obj.user_id == global_config.BOT_QQ: + return f"[回复 麦麦 的消息: {message_obj.processed_plain_text}]" + else: + return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]" + + else: + return '[回复某人消息]' + + @staticmethod + def unescape(text: str) -> str: + """反转义CQ码中的特殊字符""" + return text.replace(',', ',') \ + .replace('[', '[') \ + .replace(']', ']') \ + .replace('&', '&') + + @staticmethod + def create_emoji_cq(file_path: str) -> str: + """ + 创建表情包CQ码 + Args: + file_path: 本地表情包文件路径 + Returns: + 表情包CQ码字符串 + """ + # 确保使用绝对路径 + abs_path = os.path.abspath(file_path) + # 转义特殊字符 + escaped_path = abs_path.replace('&', '&') \ + .replace('[', '[') \ + .replace(']', ']') \ + .replace(',', ',') + # 生成CQ码,设置sub_type=1表示这是表情包 + return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" + + @staticmethod + def create_reply_cq(message_id: int) -> str: + """ + 创建回复CQ码 + Args: + message_id: 回复的消息ID + Returns: + 回复CQ码字符串 + """ + return f"[CQ:reply,id={message_id}]" \ No newline at end of file diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py new file mode 100644 index 000000000..a4352758d --- /dev/null +++ b/src/plugins/chat/emoji_manager.py @@ -0,0 +1,414 @@ +from typing import List, Dict, Optional +import random +from ...common.database import Database +import os +import json +from dataclasses import dataclass +import jieba.analyse as jieba_analyse +import aiohttp +import hashlib +from datetime import datetime +import base64 +import shutil +from .config import global_config, llm_config +import asyncio +import time + + +class EmojiManager: + _instance = None + EMOJI_DIR = "data/emoji" # 表情包存储目录 + + EMOTION_KEYWORDS = { + 'happy': ['开心', '快乐', '高兴', '欢喜', '笑', '喜悦', '兴奋', '愉快', '乐', '好'], + 'angry': ['生气', '愤怒', '恼火', '不爽', '火大', '怒', '气愤', '恼怒', '发火', '不满'], + 'sad': ['伤心', '难过', '悲伤', '痛苦', '哭', '忧伤', '悲痛', '哀伤', '委屈', '失落'], + 'surprised': ['惊讶', '震惊', '吃惊', '意外', '惊', '诧异', '惊奇', '惊喜', '不敢相信', '目瞪口呆'], + 
'disgusted': ['恶心', '讨厌', '厌恶', '反感', '嫌弃', '恶', '嫌恶', '憎恶', '不喜欢', '烦'], + 'fearful': ['害怕', '恐惧', '惊恐', '担心', '怕', '惊吓', '惊慌', '畏惧', '胆怯', '惧'], + 'neutral': ['普通', '一般', '还行', '正常', '平静', '平淡', '一般般', '凑合', '还好', '就这样'] + } + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.db = None + cls._instance._initialized = False + return cls._instance + + def __init__(self): + self.db = Database.get_instance() + self._scan_task = None + + def _ensure_emoji_dir(self): + """确保表情存储目录存在""" + os.makedirs(self.EMOJI_DIR, exist_ok=True) + + def initialize(self): + """初始化数据库连接和表情目录""" + if not self._initialized: + try: + self.db = Database.get_instance() + self._ensure_emoji_collection() + self._ensure_emoji_dir() + self._initialized = True + # 启动时执行一次完整性检查 + self.check_emoji_file_integrity() + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 初始化表情管理器失败: {str(e)}") + + def _ensure_db(self): + """确保数据库已初始化""" + if not self._initialized: + self.initialize() + if not self._initialized: + raise RuntimeError("EmojiManager not initialized") + + def _ensure_emoji_collection(self): + """确保emoji集合存在并创建索引""" + if 'emoji' not in self.db.db.list_collection_names(): + self.db.db.create_collection('emoji') + self.db.db.emoji.create_index([('tags', 1)]) + self.db.db.emoji.create_index([('filename', 1)], unique=True) + + def record_usage(self, emoji_id: str): + """记录表情使用次数""" + try: + self._ensure_db() + self.db.db.emoji.update_one( + {'_id': emoji_id}, + {'$inc': {'usage_count': 1}} + ) + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}") + + async def _get_emotion_from_text(self, text: str) -> List[str]: + """从文本中识别情感关键词,使用DeepSeek API进行分析 + Args: + text: 输入文本 + Returns: + List[str]: 匹配到的情感标签列表 + """ + try: + # 准备请求数据 + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + } + + payload = { + "model": "deepseek-ai/DeepSeek-V3", + "messages": [ + { + 
"role": "user", + "content": [ + { + "type": "text", + "text": f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。' + } + ] + } + ], + "max_tokens": 50, + "temperature": 0.3 + } + + async with aiohttp.ClientSession() as session: + async with session.post( + f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + headers=headers, + json=payload + ) as response: + if response.status != 200: + print(f"\033[1;31m[错误]\033[0m API请求失败: {await response.text()}") + return ['neutral'] + + result = json.loads(await response.text()) + if "choices" in result and len(result["choices"]) > 0: + emotion = result["choices"][0]["message"]["content"].strip().lower() + # 确保返回的标签是有效的 + if emotion in self.EMOTION_KEYWORDS: + print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}") + return [emotion] # 返回单个情感标签的列表 + + return ['neutral'] # 如果无法识别情感,返回neutral + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}") + return ['neutral'] + + async def get_emoji_for_emotion(self, emotion_tag: str) -> Optional[str]: + try: + self._ensure_db() + + # 构建查询条件:标签匹配任一情感 + query = {'tags': {'$in': emotion_tag}} + + # print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") + + try: + # 随机获取一个匹配的表情 + emoji = self.db.db.emoji.aggregate([ + {'$match': query}, + {'$sample': {'size': 1}} + ]).next() + print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") + if emoji and 'path' in emoji: + # 更新使用次数 + self.db.db.emoji.update_one( + {'_id': emoji['_id']}, + {'$inc': {'usage_count': 1}} + ) + return emoji['path'] + except StopIteration: + # 如果没有匹配的表情,从所有表情中随机选择一个 + print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") + try: + emoji = self.db.db.emoji.aggregate([ + {'$sample': {'size': 1}} + ]).next() + if emoji and 'path' in emoji: + # 更新使用次数 + self.db.db.emoji.update_one( + {'_id': emoji['_id']}, + {'$inc': {'usage_count': 1}} + ) + return emoji['path'] + except StopIteration: + print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") + return None + + return None + + except Exception 
as e: + print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") + return None + + + async def get_emoji_for_text(self, text: str) -> Optional[str]: + """根据文本内容获取相关表情包 + Args: + text: 输入文本 + Returns: + Optional[str]: 表情包文件路径,如果没有找到则返回None + """ + try: + self._ensure_db() + # 获取情感标签 + emotions = await self._get_emotion_from_text(text) + print("为 ‘"+ str(text) + "’ 获取到的情感标签为:" + str(emotions)) + if not emotions: + return None + + # 构建查询条件:标签匹配任一情感 + query = {'tags': {'$in': emotions}} + + print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}") + print(f"\033[1;34m[调试]\033[0m 匹配到的情感: {emotions}") + + try: + # 随机获取一个匹配的表情 + emoji = self.db.db.emoji.aggregate([ + {'$match': query}, + {'$sample': {'size': 1}} + ]).next() + print(f"\033[1;32m[成功]\033[0m 找到匹配的表情") + if emoji and 'path' in emoji: + # 更新使用次数 + self.db.db.emoji.update_one( + {'_id': emoji['_id']}, + {'$inc': {'usage_count': 1}} + ) + return emoji['path'] + except StopIteration: + # 如果没有匹配的表情,从所有表情中随机选择一个 + print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个") + try: + emoji = self.db.db.emoji.aggregate([ + {'$sample': {'size': 1}} + ]).next() + if emoji and 'path' in emoji: + # 更新使用次数 + self.db.db.emoji.update_one( + {'_id': emoji['_id']}, + {'$inc': {'usage_count': 1}} + ) + return emoji['path'] + except StopIteration: + print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情") + return None + + return None + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}") + return None + + async def _get_emoji_tag(self, image_base64: str) -> str: + """获取表情包的标签""" + async with aiohttp.ClientSession() as session: + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}" + } + + payload = { + "model": "deepseek-ai/deepseek-vl2", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好' + }, + { + "type": "image_url", + "image_url": { 
+ "url": f"data:image/jpeg;base64,{image_base64}" + } + } + ] + } + ], + "max_tokens": 60, + "temperature": 0.3 + } + + async with session.post( + f"{llm_config.SILICONFLOW_BASE_URL}chat/completions", + headers=headers, + json=payload + ) as response: + if response.status == 200: + result = await response.json() + if "choices" in result and len(result["choices"]) > 0: + tag_result = result["choices"][0]["message"]["content"].strip().lower() + + valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"] + for tag_match in valid_tags: + if tag_match in tag_result or tag_match == tag_result: + return tag_match + print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_match}, 跳过") + else: + print(f"\033[1;31m[错误]\033[0m 获取标签失败, 状态码: {response.status}") + + print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral") + return "skip" # 默认标签 + + async def scan_new_emojis(self): + """扫描新的表情包""" + try: + emoji_dir = "data/emoji" + os.makedirs(emoji_dir, exist_ok=True) + + # 获取所有jpg文件 + files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')] + + for filename in files_to_process: + # 检查是否已经注册过 + existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) + if existing_emoji: + continue + + image_path = os.path.join(emoji_dir, filename) + # 读取图片数据 + with open(image_path, 'rb') as f: + image_data = f.read() + + # 将图片转换为base64 + image_base64 = base64.b64encode(image_data).decode('utf-8') + + # 获取表情包的情感标签 + tag = await self._get_emoji_tag(image_base64) + if not tag == "skip": + # 准备数据库记录 + emoji_record = { + 'filename': filename, + 'path': image_path, + 'tags': [tag], + 'timestamp': int(time.time()) + } + + # 保存到数据库 + self.db.db['emoji'].insert_one(emoji_record) + print(f"\033[1;32m[成功]\033[0m 注册新表情包: {filename}") + print(f"标签: {tag}") + else: + print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}") + + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}") + import traceback + print(traceback.format_exc()) + + async def 
_periodic_scan(self, interval_MINS: int = 10): + """定期扫描新表情包""" + while True: + print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...") + await self.scan_new_emojis() + await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 + + def check_emoji_file_integrity(self): + """检查表情包文件完整性 + 如果文件已被删除,则从数据库中移除对应记录 + """ + try: + self._ensure_db() + # 获取所有表情包记录 + all_emojis = list(self.db.db.emoji.find()) + removed_count = 0 + total_count = len(all_emojis) + + for emoji in all_emojis: + try: + if 'path' not in emoji: + print(f"\033[1;33m[提示]\033[0m 发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") + self.db.db.emoji.delete_one({'_id': emoji['_id']}) + removed_count += 1 + continue + + # 检查文件是否存在 + if not os.path.exists(emoji['path']): + print(f"\033[1;33m[提示]\033[0m 表情包文件已被删除: {emoji['path']}") + # 从数据库中删除记录 + result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) + if result.deleted_count > 0: + print(f"\033[1;32m[成功]\033[0m 成功删除数据库记录: {emoji['_id']}") + removed_count += 1 + else: + print(f"\033[1;31m[错误]\033[0m 删除数据库记录失败: {emoji['_id']}") + except Exception as item_error: + print(f"\033[1;31m[错误]\033[0m 处理表情包记录时出错: {str(item_error)}") + continue + + # 验证清理结果 + remaining_count = self.db.db.emoji.count_documents({}) + if removed_count > 0: + print(f"\033[1;32m[成功]\033[0m 已清理 {removed_count} 个失效的表情包记录") + print(f"\033[1;34m[统计]\033[0m 清理前总数: {total_count} | 清理后总数: {remaining_count}") + # print(f"\033[1;34m[统计]\033[0m 应删除数量: {removed_count} | 实际删除数量: {total_count - remaining_count}") + # 执行数据库压缩 + try: + self.db.db.command({"compact": "emoji"}) + print(f"\033[1;32m[成功]\033[0m 数据库集合压缩完成") + except Exception as compact_error: + print(f"\033[1;31m[错误]\033[0m 数据库压缩失败: {str(compact_error)}") + else: + print(f"\033[1;36m[表情包]\033[0m 已检查 {total_count} 个表情包记录") + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 检查表情包完整性失败: {str(e)}") + import traceback + print(f"\033[1;31m[错误追踪]\033[0m\n{traceback.format_exc()}") + + async def start_periodic_check(self, interval_MINS: int = 120): + 
while True: + self.check_emoji_file_integrity() + await asyncio.sleep(interval_MINS * 60) + + + +# 创建全局单例 +emoji_manager = EmojiManager() \ No newline at end of file diff --git a/src/plugins/chat/gpt_response.py b/src/plugins/chat/gpt_response.py new file mode 100644 index 000000000..8b687c1fc --- /dev/null +++ b/src/plugins/chat/gpt_response.py @@ -0,0 +1,544 @@ +from typing import Dict, Any, List, Optional, Union, Tuple +from openai import OpenAI +import asyncio +import requests +from functools import partial +from .message import Message +from .config import BotConfig +from ...common.database import Database +import random +import time +import subprocess +import os +import sys +import threading +import queue +import numpy as np +from dotenv import load_dotenv +from .relationship_manager import relationship_manager +from ..schedule.schedule_generator import bot_schedule +from .prompt_builder import prompt_builder +from .config import llm_config +from .willing_manager import willing_manager +from .utils import get_embedding +import aiohttp + + +# 获取当前文件的绝对路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) +load_dotenv(os.path.join(root_dir, '.env')) + +# 常见的错别字映射 +TYPO_DICT = { + '的': '地得', + '了': '咯啦勒', + '吗': '嘛麻', + '吧': '八把罢', + '是': '事', + '在': '再在', + '和': '合', + '有': '又', + '我': '沃窝喔', + '你': '泥尼拟', + '他': '它她塔祂', + '们': '门', + '啊': '阿哇', + '呢': '呐捏', + '都': '豆读毒', + '很': '狠', + '会': '回汇', + '去': '趣取曲', + '做': '作坐', + '想': '相像', + '说': '说税睡', + '看': '砍堪刊', + '来': '来莱赖', + '好': '号毫豪', + '给': '给既继', + '过': '锅果裹', + '能': '嫩', + '为': '位未', + '什': '甚深伸', + '么': '末麽嘛', + '话': '话花划', + '知': '织直值', + '道': '到', + '听': '听停挺', + '见': '见件建', + '觉': '觉脚搅', + '得': '得德锝', + '着': '着找招', + '像': '向象想', + '等': '等灯登', + '谢': '谢写卸', + '对': '对队', + '里': '里理鲤', + '啦': '啦拉喇', + '吃': '吃持迟', + '哦': '哦喔噢', + '呀': '呀压', + '要': '药', + '太': '太抬台', + '快': '块', + '点': '店', + '以': '以已', + '因': '因应', + '啥': 
'啥沙傻', + '行': '行型形', + '哈': '哈蛤铪', + '嘿': '嘿黑嗨', + '嗯': '嗯恩摁', + '哎': '哎爱埃', + '呜': '呜屋污', + '喂': '喂位未', + '嘛': '嘛麻马', + '嗨': '嗨害亥', + '哇': '哇娃蛙', + '咦': '咦意易', + '嘻': '嘻西希' +} + +def random_remove_punctuation(text: str) -> str: + """随机处理标点符号,模拟人类打字习惯""" + result = '' + text_len = len(text) + + for i, char in enumerate(text): + if char == '。' and i == text_len - 1: # 结尾的句号 + if random.random() > 0.4: # 80%概率删除结尾句号 + continue + elif char == ',': + rand = random.random() + if rand < 0.25: # 5%概率删除逗号 + continue + elif rand < 0.25: # 20%概率把逗号变成空格 + result += ' ' + continue + result += char + return result + +def add_typos(text: str) -> str: + """随机给文本添加错别字""" + TYPO_RATE = 0.02 # 控制错别字出现的概率(1%) + + result = "" + for char in text: + if char in TYPO_DICT and random.random() < TYPO_RATE: + # 从可能的错别字中随机选择一个 + typos = TYPO_DICT[char] + result += random.choice(typos) + else: + result += char + return result + +def open_new_console_window(text: str): + """在新的控制台窗口中显示文本""" + if sys.platform == 'win32': + # 创建一个临时批处理文件 + temp_bat = "temp_output.bat" + with open(temp_bat, "w", encoding="utf-8") as f: + f.write(f'@echo off\n') + f.write(f'echo {text}\n') + f.write('pause\n') + + # 在新窗口中运行批处理文件 + subprocess.Popen(['start', 'cmd', '/c', temp_bat], shell=True) + + # 等待一会儿再删除批处理文件 + import threading + def delete_bat(): + import time + time.sleep(2) + if os.path.exists(temp_bat): + os.remove(temp_bat) + threading.Thread(target=delete_bat).start() + +class ReasoningWindow: + def __init__(self): + self.process = None + self.message_queue = queue.Queue() + self.is_running = False + self.content_file = "reasoning_content.txt" + + def start(self): + if self.process is None: + # 创建用于显示的批处理文件 + with open("reasoning_window.bat", "w", encoding="utf-8") as f: + f.write('@echo off\n') + f.write('chcp 65001\n') # 设置UTF-8编码 + f.write('title Magellan Reasoning Process\n') + f.write('echo Waiting for reasoning content...\n') + f.write(':loop\n') + f.write('if exist "reasoning_update.txt" (\n') + 
f.write(' type "reasoning_update.txt" >> "reasoning_content.txt"\n') + f.write(' del "reasoning_update.txt"\n') + f.write(' cls\n') + f.write(' type "reasoning_content.txt"\n') + f.write(')\n') + f.write('timeout /t 1 /nobreak >nul\n') + f.write('goto loop\n') + + # 清空内容文件 + with open(self.content_file, "w", encoding="utf-8") as f: + f.write("") + + # 启动新窗口 + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + self.process = subprocess.Popen(['cmd', '/c', 'start', 'reasoning_window.bat'], + shell=True, + startupinfo=startupinfo) + self.is_running = True + + # 启动处理线程 + threading.Thread(target=self._process_messages, daemon=True).start() + + def _process_messages(self): + while self.is_running: + try: + # 获取新消息 + text = self.message_queue.get(timeout=1) + # 写入更新文件 + with open("reasoning_update.txt", "w", encoding="utf-8") as f: + f.write(text) + except queue.Empty: + continue + except Exception as e: + print(f"处理推理内容时出错: {e}") + + def update_content(self, text: str): + if self.is_running: + self.message_queue.put(text) + + def stop(self): + self.is_running = False + if self.process: + self.process.terminate() + self.process = None + # 清理文件 + for file in ["reasoning_window.bat", "reasoning_content.txt", "reasoning_update.txt"]: + if os.path.exists(file): + os.remove(file) + +# 创建全局单例 +reasoning_window = ReasoningWindow() + +class GPTResponseGenerator: + def __init__(self, config: BotConfig): + self.config = config + self.client = OpenAI( + api_key=llm_config.SILICONFLOW_API_KEY, + base_url=llm_config.SILICONFLOW_BASE_URL + ) + + self.db = Database.get_instance() + reasoning_window.start() + # 当前使用的模型类型 + self.current_model_type = 'r1' # 默认使用 R1 + + async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]: + """根据当前模型类型选择对应的生成函数""" + # 使用随机数选择模型 + rand = random.random() + if rand < 0.15: # 40%概率使用 R1 + self.current_model_type = "r1" + elif rand < 0.8: # 30%概率使用 V3 + self.current_model_type = 
"v3" + else: # 30%概率使用 R1-Distill + self.current_model_type = "r1_distill" + + print(f"+++++++++++++++++麦麦{self.current_model_type}思考中+++++++++++++++++") + if self.current_model_type == 'r1': + model_response = await self._generate_r1_response(message) + elif self.current_model_type == 'v3': + model_response = await self._generate_v3_response(message) + else: + model_response = await self._generate_r1_distill_response(message) + + # 打印情感标签 + print(f'麦麦的回复是:{model_response}') + model_response , emotion = await self._process_response(model_response) + + if model_response: + print(f"为 '{model_response}' 获取到的情感标签为:{emotion}") + + return model_response,emotion + + async def _generate_r1_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]: + """使用 DeepSeek-R1 模型生成回复""" + # 获取群聊上下文 + group_chat = await self._get_group_chat_context(message) + sender_name = message.user_nickname or f"用户{message.user_id}" + if relationship_manager.get_relationship(message.user_id): + relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value + print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + else: + relationship_value = 0.0 + + # 构建 prompt + prompt = prompt_builder._build_prompt( + message_txt=message.processed_plain_text, + sender_name=sender_name, + relationship_value=relationship_value, + group_id=message.group_id + ) + + def create_completion(): + return self.client.chat.completions.create( + model="Pro/deepseek-ai/DeepSeek-R1", + messages=[{"role": "user", "content": prompt}], + stream=False, + max_tokens=1024 + ) + + loop = asyncio.get_event_loop() + response = await loop.run_in_executor(None, create_completion) + if response.choices[0].message.content: + print(response.choices[0].message.content) + print(response.choices[0].message.reasoning_content) + # 处理 R1 特有的返回格式 + content = response.choices[0].message.content + reasoning_content = response.choices[0].message.reasoning_content + else: + 
return None + # 更新推理窗口 + self._update_reasoning_window(message, prompt, reasoning_content, content, sender_name) + + return content + + async def _generate_v3_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]: + """使用 DeepSeek-V3 模型生成回复""" + # 获取群聊上下文 + group_chat = await self._get_group_chat_context(message) + sender_name = message.user_nickname or f"用户{message.user_id}" + + if relationship_manager.get_relationship(message.user_id): + relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value + print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + else: + relationship_value = 0.0 + + prompt = prompt_builder._build_prompt(message.processed_plain_text, sender_name,relationship_value,group_id=message.group_id) + + messages = [{"role": "user", "content": prompt}] + + loop = asyncio.get_event_loop() + create_completion = partial( + self.client.chat.completions.create, + model="Pro/deepseek-ai/DeepSeek-V3", + messages=messages, + stream=False, + max_tokens=1024, + temperature=0.8 + ) + response = await loop.run_in_executor(None, create_completion) + + if response.choices[0].message.content: + content = response.choices[0].message.content + # V3 模型没有 reasoning_content + self._update_reasoning_window(message, prompt, "V3模型无推理过程", content, sender_name) + return content + else: + print(f"[ERROR] V3 回复发送生成失败: {response}") + + return None, [] # 返回元组 + + async def _generate_r1_distill_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]: + """使用 DeepSeek-R1-Distill-Qwen-32B 模型生成回复""" + # 获取群聊上下文 + group_chat = await self._get_group_chat_context(message) + sender_name = message.user_nickname or f"用户{message.user_id}" + if relationship_manager.get_relationship(message.user_id): + relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value + print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") + else: + 
relationship_value = 0.0 + + # 构建 prompt + prompt = prompt_builder._build_prompt( + message_txt=message.processed_plain_text, + sender_name=sender_name, + relationship_value=relationship_value, + group_id=message.group_id + ) + + def create_completion(): + return self.client.chat.completions.create( + model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + messages=[{"role": "user", "content": prompt}], + stream=False, + max_tokens=1024 + ) + + loop = asyncio.get_event_loop() + response = await loop.run_in_executor(None, create_completion) + if response.choices[0].message.content: + print(response.choices[0].message.content) + print(response.choices[0].message.reasoning_content) + # 处理 R1 特有的返回格式 + content = response.choices[0].message.content + reasoning_content = response.choices[0].message.reasoning_content + else: + return None + # 更新推理窗口 + self._update_reasoning_window(message, prompt, reasoning_content, content, sender_name) + + return content + + async def _get_group_chat_context(self, message: Message) -> str: + """获取群聊上下文""" + recent_messages = self.db.db.messages.find( + {"group_id": message.group_id} + ).sort("time", -1).limit(15) + + messages_list = list(recent_messages)[::-1] + group_chat = "" + + for msg_dict in messages_list: + time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(msg_dict['time'])) + display_name = msg_dict.get('user_nickname', f"用户{msg_dict['user_id']}") + content = msg_dict.get('processed_plain_text', msg_dict['plain_text']) + + group_chat += f"[{time_str}] {display_name}: {content}\n" + + return group_chat + + def _update_reasoning_window(self, message, prompt, reasoning_content, content, sender_name): + """更新推理窗口内容""" + current_time = time.strftime("%Y-%m-%d %H:%M:%S") + + # 获取当前使用的模型名称 + model_name = { + 'r1': 'DeepSeek-R1', + 'v3': 'DeepSeek-V3', + 'r1_distill': 'DeepSeek-R1-Distill-Qwen-32B' + }.get(self.current_model_type, '未知模型') + + display_text = ( + f"Time: {current_time}\n" + f"Group: {message.group_name}\n" + f"User: 
{sender_name}\n" + f"Model: {model_name}\n" + f"\033[1;32mMessage:\033[0m {message.processed_plain_text}\n\n" + f"\033[1;32mPrompt:\033[0m \n{prompt}\n" + f"\n-------------------------------------------------------" + f"\n\033[1;32mReasoning Process:\033[0m\n{reasoning_content}\n" + f"\n\033[1;32mResponse Content:\033[0m\n{content}\n" + f"\n{'='*50}\n" + ) + reasoning_window.update_content(display_text) + + async def _get_emotion_tags(self, content: str) -> List[str]: + """提取情感标签""" + try: + prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出 + 只输出标签就好,不要输出其他内容: + 内容:{content} + 输出: + ''' + + messages = [{"role": "user", "content": prompt}] + + loop = asyncio.get_event_loop() + create_completion = partial( + self.client.chat.completions.create, + model="Pro/deepseek-ai/DeepSeek-V3", + messages=messages, + stream=False, + max_tokens=30, + temperature=0.6 + ) + response = await loop.run_in_executor(None, create_completion) + + if response.choices[0].message.content: + # 确保返回的是列表格式 + emotion_tag = response.choices[0].message.content.strip() + return [emotion_tag] # 将单个标签包装成列表返回 + + return ["neutral"] # 如果无法获取情感标签,返回默认值 + + except Exception as e: + print(f"获取情感标签时出错: {e}") + return ["neutral"] # 发生错误时返回默认值 + + async def _process_response(self, content: str) -> Tuple[Union[str, List[str]], List[str]]: + """处理响应内容,返回处理后的内容和情感标签""" + if not content: + return None, [] + + emotion_tags = await self._get_emotion_tags(content) + + # 添加错别字和处理标点符号 + if random.random() < 0.9: # 90%概率进行处理 + processed_response = random_remove_punctuation(add_typos(content)) + else: + processed_response = content + # 处理长消息 + if len(processed_response) > 5: + sentences = self._split_into_sentences(processed_response) + print(f"分割后的句子: {sentences}") + messages = [] + current_message = "" + + for sentence in sentences: + if len(current_message) + len(sentence) <= 5: + current_message += ' ' + current_message += sentence + else: + if current_message: + 
class GroupInfoManager:
    """Persist and look up QQ group / user metadata in MongoDB.

    Records live in the ``group_info`` and ``user_info`` collections of the
    shared ``Database`` singleton.  Database errors are printed and swallowed,
    so the public coroutines never raise to their callers.
    """

    def __init__(self):
        self.db = Database.get_instance()
        # Make sure both collections exist before they are first used.
        self._ensure_collections()

    def _ensure_collections(self):
        """Create the ``group_info`` / ``user_info`` collections if missing."""
        existing = self.db.db.list_collection_names()
        for name in ('group_info', 'user_info'):
            if name not in existing:
                self.db.db.create_collection(name)

    async def update_group_info(self, group_id: int, group_name: str, group_notice: str = "",
                                member_count: int = 0, admins: list = None):
        """Insert or refresh the record of one group.

        Args:
            group_id: Group number, used as the lookup key.
            group_name: Display name of the group.
            group_notice: Group announcement text.
            member_count: Number of members.
            admins: Administrator ids; ``None`` is stored as an empty list.
        """
        try:
            group_data = {
                "group_id": group_id,
                "group_name": group_name,
                "group_notice": group_notice,
                "member_count": member_count,
                "admins": admins or [],
                "last_updated": time.time()
            }

            # Upsert: update the existing document or create a new one.
            self.db.db.group_info.update_one(
                {"group_id": group_id},
                {"$set": group_data},
                upsert=True
            )
        except Exception as e:
            print(f"\033[1;31m[错误]\033[0m 更新群信息失败: {str(e)}")

    async def update_user_info(self, user_id: int, nickname: str, group_id: int = None,
                               group_card: str = None, age: int = None, gender: str = None,
                               location: str = None):
        """Insert or refresh the record of one user.

        Optional profile fields are only written when they are not ``None``.
        When ``group_id`` is given, the user's per-group data is written under
        the dotted path ``group_info.<group_id>`` (MongoDB stores this as a
        nested document) and the group is added to the user's ``groups`` set.
        """
        try:
            # Fields that are always written.
            user_data = {
                "user_id": user_id,
                "nickname": nickname,
                "last_updated": time.time()
            }

            # Optional profile fields.
            if age is not None:
                user_data["age"] = age
            if gender is not None:
                user_data["gender"] = gender
            if location is not None:
                user_data["location"] = location

            update_doc = {"$set": user_data}

            if group_id is not None:
                user_data[f"group_info.{group_id}"] = {
                    "group_card": group_card,
                    "last_active": time.time()
                }
                # Only include $addToSet when there is something to add: an
                # empty operator document is rejected by older MongoDB servers.
                update_doc["$addToSet"] = {"groups": group_id}

            self.db.db.user_info.update_one(
                {"user_id": user_id},
                update_doc,
                upsert=True
            )

        except Exception as e:
            print(f"\033[1;31m[错误]\033[0m 更新用户信息失败: {str(e)}")
            print(f"用户ID: {user_id}, 昵称: {nickname}, 群ID: {group_id}, 群名片: {group_card}")

    async def get_group_info(self, group_id: int) -> Optional[Dict]:
        """Return the stored document for ``group_id``, or ``None``."""
        try:
            return self.db.db.group_info.find_one({"group_id": group_id})
        except Exception as e:
            print(f"\033[1;31m[错误]\033[0m 获取群信息失败: {str(e)}")
            return None

    async def get_user_info(self, user_id: int, group_id: int = None) -> Optional[Dict]:
        """Return the stored document for ``user_id``, or ``None``.

        When ``group_id`` is given, the returned dict additionally carries
        ``current_group_info``: the user's data for that group, or ``{}`` if
        nothing has been recorded for it.
        """
        try:
            user_info = self.db.db.user_info.find_one({"user_id": user_id})
            if user_info and group_id is not None:
                # update_user_info writes "group_info.<id>" through $set, which
                # produces a nested document keyed by the id as a string.
                per_group = user_info.get("group_info")
                if not isinstance(per_group, dict):
                    per_group = {}
                user_info["current_group_info"] = per_group.get(str(group_id), {})
            return user_info
        except Exception as e:
            print(f"\033[1;31m[错误]\033[0m 获取用户信息失败: {str(e)}")
            return None
def storage_emoji(image_data: bytes) -> bytes:
    """Store an emoji image under ``data/emoji`` unless it is already there.

    Files are named ``<unix timestamp>_<crc32 hex>.jpg``.  An image counts as
    a duplicate only when an existing file's hash part equals the CRC32 of
    ``image_data`` exactly; a substring test would let the unpadded hex digest
    match unrelated file names and silently drop new images.

    Args:
        image_data: Raw image bytes; written to disk unchanged.

    Returns:
        bytes: ``image_data`` itself, whether or not a file was written.
        Any error is printed and the original bytes are still returned.
    """
    try:
        # CRC32 of the content, as lowercase hex without padding.
        hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')

        # Make sure the emoji directory exists.
        emoji_dir = "data/emoji"
        os.makedirs(emoji_dir, exist_ok=True)

        # Skip the write when a file with exactly this hash already exists.
        for filename in os.listdir(emoji_dir):
            stem = os.path.splitext(filename)[0]
            if stem.rpartition('_')[2] == hash_value:
                return image_data

        timestamp = int(time.time())
        filename = f"{timestamp}_{hash_value}.jpg"
        emoji_path = os.path.join(emoji_dir, filename)

        # The bytes are saved as received; no re-encoding takes place.
        with open(emoji_path, "wb") as f:
            f.write(image_data)

        print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
        return image_data

    except Exception as e:
        print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
        return image_data
class LLMResponseGenerator:
    """Generate chat replies through the SiliconFlow chat-completions API.

    Each call picks DeepSeek-R1 with probability
    ``config.MODEL_R1_PROBABILITY`` and DeepSeek-V3 otherwise.
    """

    def __init__(self, config: "BotConfig"):
        self.config = config
        self.API_KEY = os.getenv('SILICONFLOW_KEY')
        self.BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
        self.client = OpenAI(
            api_key=self.API_KEY,
            base_url=self.BASE_URL
        )

        self.db = Database.get_instance()
        # Model chosen for the most recent request; R1 until the first call.
        self.current_model_type = 'r1'

    async def generate_response(self, text: str) -> Optional[str]:
        """Pick a model at random and return its reply to ``text``.

        Returns:
            The reply text, or ``None`` when the API response has no choices.
        """
        if random.random() < self.config.MODEL_R1_PROBABILITY:
            self.current_model_type = "r1"
        else:
            self.current_model_type = "v3"

        print(f"+++++++++++++++++麦麦{self.current_model_type}思考中+++++++++++++++++")
        # Dispatch to the model that was actually selected; previously both
        # branches called the V3 generator, so R1 was never used.
        if self.current_model_type == 'r1':
            model_response = await self._generate_r1_response(text)
        else:
            model_response = await self._generate_v3_response(text)
        print(f'麦麦的回复------------------------------是:{model_response}')

        return model_response

    async def _request_completion(self, payload: dict) -> dict:
        """POST ``payload`` to ``/chat/completions`` and return the parsed JSON."""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.API_KEY}"
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.BASE_URL}/chat/completions",
                                    headers=headers,
                                    json=payload) as response:
                return await response.json()

    async def _generate_r1_response(self, text: str) -> Optional[str]:
        """Ask DeepSeek-R1 for a reply; also prints its reasoning trace."""
        result = await self._request_completion({
            "model": "Pro/deepseek-ai/DeepSeek-R1",
            "messages": [{"role": "user", "content": text}],
            "stream": False,
            "max_tokens": 1024,
            "temperature": 0.8
        })

        if "choices" in result and len(result["choices"]) > 0:
            message = result["choices"][0]["message"]
            content = message["content"]
            reasoning_content = message.get("reasoning_content", "")
            print(f"Content: {content}")
            print(f"Reasoning: {reasoning_content}")
            return content

        return None

    async def _generate_v3_response(self, text: str) -> Optional[str]:
        """Ask DeepSeek-V3 for a reply; prints the raw result on failure."""
        result = await self._request_completion({
            "model": "Pro/deepseek-ai/DeepSeek-V3",
            "messages": [{"role": "user", "content": text}],
            "max_tokens": 1024,
            "temperature": 0.8
        })

        if "choices" in result and len(result["choices"]) > 0:
            return result["choices"][0]["message"]["content"]

        print(f"[ERROR] V3 回复发送生成失败: {result}")
        return None
@dataclass
class Message:
    """One chat message, received from a group or produced by the bot.

    On construction the instance fills in missing values: ``time`` defaults
    to the current second, ``user_nickname`` / ``group_name`` are looked up in
    the database, and, when ``processed_plain_text`` is empty and
    ``raw_message`` is set, the raw text is parsed into ``message_segments``
    whose translated texts are joined into ``processed_plain_text``.
    """
    group_id: int = None
    user_id: int = None
    user_nickname: str = None  # sender's nickname
    group_name: str = None  # group display name

    message_id: int = None
    raw_message: str = None
    plain_text: str = None

    message_based_id: int = None
    reply_message: Dict = None  # the message being replied to, if any

    message_segments: List[Dict] = None  # parsed text / CQ-code segments
    processed_plain_text: str = None  # text after CQ codes were translated

    time: float = None

    is_emoji: bool = False  # True when the message is a single sticker image

    reply_benefits: float = 0.0

    type: str = 'received'  # either 'received' or 'send'

    # "Thinking" placeholder state.
    is_thinking: bool = False
    thinking_text: str = "正在思考..."
    thingking_start_time: float = None  # (sic) name kept: part of the constructor signature
    thinking_time: float = 0

    # Unannotated on purpose in the original: plain class attributes, not
    # dataclass fields, so they are not constructor parameters.
    received_message = ''
    thinking_response = ''

    def __post_init__(self):
        if self.time is None:
            self.time = int(time.time())

        if not self.user_nickname:
            self.user_nickname = self.get_user_nickname(self.user_id)

        if not self.group_name:
            self.group_name = self.get_groupname(self.group_id)

        if not self.processed_plain_text:
            # Only parse when there is raw text to parse.
            if self.raw_message:
                self.message_segments = self.parse_message_segments(str(self.raw_message))
                self.processed_plain_text = ' '.join(
                    seg['translated_text']
                    for seg in self.message_segments
                )

    def get_user_nickname(self, user_id: int) -> str:
        """Return the nickname stored for ``user_id``.

        Falls back to ``"未知用户"`` for a falsy id, ``"麦麦"`` for the bot's
        own id, and ``"用户<id>"`` when no nickname is stored.
        """
        if not user_id:
            return "未知用户"

        user_id = int(user_id)
        if user_id == int(global_config.BOT_QQ):
            return "麦麦"

        db = Database.get_instance()
        query = {'user_id': user_id}
        user = db.db.user_info.find_one(query)
        if user:
            return user.get('nickname') or f"用户{user_id}"
        else:
            return f"用户{user_id}"

    def get_groupname(self, group_id: int) -> str:
        """Return the name stored for ``group_id``.

        Falls back to ``"未知群"`` for a falsy id and to ``"群<id>"`` when the
        group is unknown or its document carries no name.
        """
        if not group_id:
            return "未知群"
        group_id = int(group_id)
        db = Database.get_instance()
        query = {'group_id': group_id}
        group = db.db.group_info.find_one(query)
        if group:
            # A document without a usable name must not leak None to callers.
            return group.get('group_name') or f"群{group_id}"
        else:
            return f"群{group_id}"

    def parse_message_segments(self, message: str) -> List[Dict]:
        """Split ``message`` into plain-text and CQ-code segments.

        Each returned dict has:
            - ``type``: ``'text'`` or the CQ code's type
            - ``data``: ``{'text': ...}`` for text, the parameter dict for a CQ code
            - ``translated_text``: the text representation used downstream

        Whitespace-only text is dropped.  An unclosed ``[CQ:`` and any CQ code
        that fails to parse are kept as plain text.  Sets ``self.is_emoji``
        when the message consists of exactly one image segment with subtype
        ``'0'``.
        """
        segments = []

        def add_text(text: str) -> None:
            # Append a text segment unless it is empty after stripping.
            text = text.strip()
            if text:
                segments.append({
                    'type': 'text',
                    'data': {'text': text},
                    'translated_text': text
                })

        start = 0
        while True:
            # Locate the next CQ code.
            cq_start = message.find('[CQ:', start)
            if cq_start == -1:
                # No more CQ codes: the remainder is plain text.
                if start < len(message):
                    add_text(message[start:])
                break

            # Plain text preceding the CQ code.
            if cq_start > start:
                add_text(message[start:cq_start])

            cq_end = message.find(']', cq_start)
            if cq_end == -1:
                # Unclosed CQ code: treat the rest as plain text.
                add_text(message[cq_start:])
                break

            cq_code = message[cq_start:cq_end + 1]
            try:
                cq_obj = CQCode.from_cq_code(cq_code, reply=self.reply_message)
                segments.append({
                    'type': cq_obj.type,
                    'data': cq_obj.params,
                    'translated_text': cq_obj.translated_plain_text
                })
            except Exception as e:
                import traceback
                print(f"\033[1;31m[错误]\033[0m 处理CQ码失败: {str(e)}")
                print(f"CQ码内容: {cq_code}")
                print(f"当前消息属性:")
                print(f"- group_id: {self.group_id}")
                print(f"- user_id: {self.user_id}")
                print(f"- user_nickname: {self.user_nickname}")
                print(f"- group_name: {self.group_name}")
                print("详细错误信息:")
                print(traceback.format_exc())
                # Keep the unparsable CQ code verbatim (not stripped).
                segments.append({
                    'type': 'text',
                    'data': {'text': cq_code},
                    'translated_text': cq_code
                })

            start = cq_end + 1

        # A lone image segment with subtype '0' is flagged as a sticker.
        if len(segments) == 1 and segments[0]['type'] == 'image':
            if segments[0]['data'].get('subtype') == '0':
                self.is_emoji = True

        return segments
class MessageSet:
    """An ordered group of related messages sharing one ``message_id``.

    ``messages`` is kept sorted by each message's ``time`` attribute.
    """

    def __init__(self, group_id: int, user_id: int, message_id: str):
        self.group_id = group_id
        self.user_id = user_id
        self.message_id = message_id
        self.messages: List["Message"] = []
        self.time = round(time.time(), 2)

    def add_message(self, message: "Message") -> None:
        """Add ``message`` and keep the collection sorted by time."""
        self.messages.append(message)
        self.messages.sort(key=lambda x: x.time)

    def get_message_by_index(self, index: int) -> Optional["Message"]:
        """Return the message at ``index``, or ``None`` if out of range.

        Negative indices are treated as out of range.
        """
        if 0 <= index < len(self.messages):
            return self.messages[index]
        return None

    def get_message_by_time(self, target_time: float) -> Optional["Message"]:
        """Return the message whose time is closest to ``target_time``.

        Returns ``None`` for an empty set.  When two neighbours are equally
        close, the later one is returned.
        """
        if not self.messages:
            return None

        # Binary search for the first message at or after target_time
        # (or the last message when every message is earlier).
        left, right = 0, len(self.messages) - 1
        while left < right:
            mid = (left + right) // 2
            if self.messages[mid].time < target_time:
                left = mid + 1
            else:
                right = mid

        candidate = self.messages[left]
        if left > 0:
            # The message just before the candidate may be closer.
            previous = self.messages[left - 1]
            if abs(target_time - previous.time) < abs(candidate.time - target_time):
                return previous
        return candidate

    def get_latest_message(self) -> Optional["Message"]:
        """Return the newest message, or ``None`` if empty."""
        return self.messages[-1] if self.messages else None

    def get_earliest_message(self) -> Optional["Message"]:
        """Return the oldest message, or ``None`` if empty."""
        return self.messages[0] if self.messages else None

    def get_all_messages(self) -> List["Message"]:
        """Return a shallow copy of the message list."""
        return self.messages.copy()

    def get_message_count(self) -> int:
        """Return the number of stored messages."""
        return len(self.messages)

    def clear_messages(self) -> None:
        """Remove every message."""
        self.messages.clear()

    def remove_message(self, message: "Message") -> bool:
        """Remove ``message``; return whether it was present."""
        if message in self.messages:
            self.messages.remove(message)
            return True
        return False

    def __str__(self) -> str:
        return f"MessageSet(id={self.message_id}, count={len(self.messages)})"

    def __len__(self) -> int:
        return len(self.messages)

    @property
    def processed_plain_text(self) -> str:
        """Newline-joined text of all messages that have non-empty text."""
        return "\n".join(msg.processed_plain_text for msg in self.messages if msg.processed_plain_text)
class SendTempContainer:
    """Registry holding one ``SendTemp`` outgoing queue per group."""

    def __init__(self):
        self.temp_queues: Dict[int, SendTemp] = {}

    def get_queue(self, group_id: int) -> SendTemp:
        """Return the queue of ``group_id``, creating it on first access."""
        try:
            return self.temp_queues[group_id]
        except KeyError:
            created = SendTemp(group_id)
            self.temp_queues[group_id] = created
            return created

    def add_message(self, message: Message) -> None:
        """Enqueue ``message`` in the queue of its own group."""
        self.get_queue(message.group_id).add(message)

    def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
        """Return every pending message of ``group_id``."""
        return self.get_queue(group_id).get_all()

    def has_messages(self, group_id: int) -> bool:
        """Tell whether ``group_id`` has anything waiting to be sent."""
        return self.get_queue(group_id).has_messages()

    def get_all_groups(self) -> List[int]:
        """Return the ids of all groups that currently own a queue."""
        return list(self.temp_queues)

    def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
        """Replace a queued placeholder with its finished message(s).

        The first queued entry whose ``message_id`` equals that of
        ``message_obj`` is replaced in place: by ``message_obj`` itself, or,
        for a ``MessageSet``, by all of its messages in order.

        Args:
            message_obj: The finished message or message set.

        Returns:
            bool: ``False`` when no queued entry matches, ``True`` otherwise.
        """
        queue = self.get_queue(message_obj.group_id)

        position = next(
            (
                idx for idx, queued in enumerate(queue.messages)
                if queued.message_id == message_obj.message_id
            ),
            None,
        )
        if position is None:
            return False

        # Work on a list copy: a deque does not support slice assignment.
        snapshot = list(queue.messages)
        if isinstance(message_obj, MessageSet):
            snapshot[position:position + 1] = message_obj.messages
        else:
            snapshot[position] = message_obj

        # Refill the existing deque so its maxlen keeps applying.
        queue.messages.clear()
        queue.messages.extend(snapshot)

        return True
int(thinking_time) % 10 == 0: + print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{thinking_time:.1f}秒") + continue + else: + print(f"\033[1;34m[调试]\033[0m 思考消息超时,移除") + queue.get_earliest_message() # 移除超时的思考消息 + + elif isinstance(message, Message): + message = queue.get_earliest_message() + if message and message.processed_plain_text: + print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}") + + cost_time = round(time.time(), 2) - message.time + # print(f"\033[1;34m[调试]\033[0m 消息发送111111时间: {cost_time}秒") + if cost_time > 40: + message.processed_plain_text = CQCode.create_reply_cq(message.message_based_id) + message.processed_plain_text + + + + + await self._current_bot.send_group_msg( + group_id=group_id, + message=str(message.processed_plain_text), + auto_escape=False + ) + + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) + print(f"\033[1;32m群 {group_id} 消息, 用户 麦麦, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}") + await self.storage.store_message(message, None) + + queue.update_send_time() + + if queue.has_messages(): + await asyncio.sleep( + random.uniform( + self.message_interval[0], + self.message_interval[1] + ) + ) + + + async def process_group_queue(self, bot: Bot, group_id: int) -> None: + """处理指定群组的消息队列""" + queue = self.send_temp_container.get_queue(group_id) + while queue.has_messages(): + message = queue.get_earliest_message() + if message and message.processed_plain_text: + await self.send_message( + bot=bot, + group_id=group_id, + content=message.processed_plain_text + ) + queue.update_send_time() + + if queue.has_messages(): + await asyncio.sleep( + random.uniform(self.message_interval[0], self.message_interval[1]) + ) + + async def process_all_queues(self, bot: Bot) -> None: + """处理所有群组的消息队列""" + if not self._running or self._paused: + return + + for group_id in self.send_temp_container.get_all_groups(): + await self.process_group_queue(bot, group_id) + + async def send_temp_message(self, + bot: 
Bot, + group_id: int, + message: Union[Message, Message_Thinking], + with_emoji: bool = False, + emoji_path: Optional[str] = None) -> bool: + """ + 发送单个临时消息 + Args: + bot: Bot实例 + group_id: 群组ID + message: Message对象 + with_emoji: 是否带表情 + emoji_path: 表情图片路径 + Returns: + bool: 发送是否成功 + """ + try: + if with_emoji and emoji_path: + return await self.send_with_emoji( + bot=bot, + group_id=group_id, + text_content=message.processed_plain_text, + emoji_path=emoji_path + ) + else: + return await self.send_message( + bot=bot, + group_id=group_id, + content=message.processed_plain_text + ) + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 发送临时消息失败: {str(e)}") + return False + + def set_typing_speed(self, min_speed: float, max_speed: float): + """设置打字速度范围""" + self.typing_speed = (min_speed, max_speed) + + def set_message_interval(self, min_interval: float, max_interval: float): + """设置消息间隔范围""" + self.message_interval = (min_interval, max_interval) + + def pause(self): + """暂停消息处理""" + self._paused = True + + def resume(self): + """恢复消息处理""" + self._paused = False + + def stop(self): + """停止消息处理""" + self._running = False + +# 创建全局实例 +message_sender = MessageSendControl() diff --git a/src/plugins/chat/message_stream.py b/src/plugins/chat/message_stream.py new file mode 100644 index 000000000..23a8b7b9d --- /dev/null +++ b/src/plugins/chat/message_stream.py @@ -0,0 +1,264 @@ +from typing import List, Optional, Dict +from .message import Message +import time +from collections import deque +from datetime import datetime, timedelta +import os +import json +import asyncio + +class MessageStream: + """单个群组的消息流容器""" + def __init__(self, group_id: int, max_size: int = 1000): + self.group_id = group_id + self.messages = deque(maxlen=max_size) + self.max_size = max_size + self.last_save_time = time.time() + + # 确保日志目录存在 + self.log_dir = os.path.join("log", str(self.group_id)) + os.makedirs(self.log_dir, exist_ok=True) + + # 启动自动保存任务 + asyncio.create_task(self._auto_save()) + + 
async def _auto_save(self): + """每30秒自动保存一次消息记录""" + while True: + await asyncio.sleep(30) # 等待30秒 + await self.save_to_log() + + async def save_to_log(self): + """将消息保存到日志文件""" + try: + current_time = time.time() + # 只有有新消息时才保存 + if not self.messages or self.last_save_time == current_time: + return + + # 生成日志文件名 (使用当前日期) + date_str = time.strftime("%Y-%m-%d", time.localtime(current_time)) + log_file = os.path.join(self.log_dir, f"chat_{date_str}.log") + + # 获取需要保存的新消息 + new_messages = [ + msg for msg in self.messages + if msg.time > self.last_save_time + ] + + if not new_messages: + return + + # 将消息转换为可序列化的格式 + message_logs = [] + for msg in new_messages: + message_logs.append({ + "time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)), + "user_id": msg.user_id, + "user_nickname": msg.user_nickname, + "message_id": msg.message_id, + "raw_message": msg.raw_message, + "processed_text": msg.processed_plain_text + }) + + # 追加写入日志文件 + with open(log_file, "a", encoding="utf-8") as f: + for log in message_logs: + f.write(json.dumps(log, ensure_ascii=False) + "\n") + + self.last_save_time = current_time + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}") + + def add_message(self, message: Message) -> None: + """按时间顺序添加新消息到队列 + + 使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。 + + Args: + message: Message对象,要添加的新消息 + """ + + # 空队列或消息应该添加到末尾的情况 + if (not self.messages or + message.time >= self.messages[-1].time): + self.messages.append(message) + return + + # 消息应该添加到开头的情况 + if message.time <= self.messages[0].time: + self.messages.appendleft(message) + return + + # 使用二分查找在现有队列中找到合适的插入位置 + left, right = 0, len(self.messages) - 1 + while left <= right: + mid = (left + right) // 2 + if self.messages[mid].time < message.time: + left = mid + 1 + else: + right = mid - 1 + + temp = list(self.messages) + temp.insert(left, message) + + # 如果超出最大长度,移除多余的消息 + if len(temp) > self.max_size: + temp = temp[-self.max_size:] + + # 重建队列 + 
self.messages = deque(temp, maxlen=self.max_size) + + async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]: + """从数据库中获取最近的消息记录 + + Args: + count: 需要获取的消息数量 + + Returns: + List[Message]: 最近的消息列表 + """ + try: + from ...common.database import Database + db = Database.get_instance() + + # 从数据库中查询最近的消息 + recent_messages = list(db.db.messages.find( + {"group_id": self.group_id}, + { + "time": 1, + "user_id": 1, + "user_nickname": 1, + "message_id": 1, + "raw_message": 1, + "processed_text": 1 + } + ).sort("time", -1).limit(count)) + + if not recent_messages: + return [] + + # 转换为 Message 对象 + from .message import Message + messages = [] + for msg_data in recent_messages: + msg = Message( + time=msg_data["time"], + user_id=msg_data["user_id"], + user_nickname=msg_data.get("user_nickname", ""), + message_id=msg_data["message_id"], + raw_message=msg_data["raw_message"], + processed_plain_text=msg_data.get("processed_text", ""), + group_id=self.group_id + ) + messages.append(msg) + + return list(reversed(messages)) # 返回按时间正序的消息 + + except Exception as e: + print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}") + return [] + + def get_recent_messages(self, count: int = 10) -> List[Message]: + """获取最近的n条消息(从内存队列)""" + print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录") + return list(self.messages)[-count:] + + def get_messages_in_timerange(self, + start_time: Optional[float] = None, + end_time: Optional[float] = None) -> List[Message]: + """获取时间范围内的消息""" + if start_time is None: + start_time = time.time() - 3600 + if end_time is None: + end_time = time.time() + + return [ + msg for msg in self.messages + if start_time <= msg.time <= end_time + ] + + def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]: + """获取特定用户的最近消息""" + user_messages = [msg for msg in self.messages if msg.user_id == user_id] + return user_messages[-count:] + + def clear_old_messages(self, hours: int = 24) -> None: + 
class MessageStreamContainer:
    """Registry of per-group message streams, keyed by group ID."""

    def __init__(self, max_size: int = 1000):
        self.streams: Dict[int, MessageStream] = {}
        self.max_size = max_size

    async def save_all_logs(self):
        """Write the pending messages of every group to its log file."""
        for group_stream in self.streams.values():
            await group_stream.save_to_log()

    def add_message(self, message: Message) -> None:
        """Route a message to its group's stream, creating the stream on first use.

        Messages without a (truthy) group ID are ignored.
        """
        gid = message.group_id
        if not gid:
            return

        target = self.streams.get(gid)
        if target is None:
            target = self.streams[gid] = MessageStream(gid, self.max_size)
        target.add_message(message)

    def get_stream(self, group_id: int) -> Optional[MessageStream]:
        """Return the stream of one group, or None if the group is unknown."""
        return self.streams.get(group_id)

    def get_all_streams(self) -> Dict[int, MessageStream]:
        """Return the live mapping of group ID to stream."""
        return self.streams

    def clear_old_messages(self, hours: int = 24) -> None:
        """Drop messages older than ``hours`` hours from every stream."""
        for group_stream in self.streams.values():
            group_stream.clear_old_messages(hours)

    def get_group_stats(self, group_id: int) -> Dict:
        """Summarise the buffered messages of one group.

        Returns:
            A dict with the message count, the number of distinct senders,
            up to five ``(hour, count)`` pairs ordered by descending count,
            and the ID of the most active sender (None for an empty group).
        """
        group_stream = self.streams.get(group_id)
        if not group_stream:
            return {
                "total_messages": 0,
                "unique_users": 0,
                "active_hours": [],
                "most_active_user": None
            }

        history = group_stream.messages
        per_user = {}
        per_hour = {}
        for entry in history:
            sender = entry.user_id
            per_user.setdefault(sender, 0)
            per_user[sender] += 1
            hour_of_day = datetime.fromtimestamp(entry.time).hour
            per_hour.setdefault(hour_of_day, 0)
            per_hour[hour_of_day] += 1

        # Ties resolve to the sender seen first, as max() keeps the first maximum.
        top_sender = max(per_user, key=per_user.get) if per_user else None
        # Stable sort on the negated count == descending count, ties in first-seen order.
        busiest_hours = sorted(per_hour.items(), key=lambda pair: -pair[1])[:5]

        return {
            "total_messages": len(history),
            "unique_users": len(per_user),
            "active_hours": busiest_hours,
            "most_active_user": top_sender
        }
+message_stream_container = MessageStreamContainer() diff --git a/src/plugins/chat/message_visualizer.py b/src/plugins/chat/message_visualizer.py new file mode 100644 index 000000000..2dd3f98e7 --- /dev/null +++ b/src/plugins/chat/message_visualizer.py @@ -0,0 +1,138 @@ +import subprocess +import threading +import queue +import os +import time +from typing import Dict +from .message import Message_Thinking + +class MessageVisualizer: + def __init__(self): + self.process = None + self.message_queue = queue.Queue() + self.is_running = False + self.content_file = "message_queue_content.txt" + + def start(self): + if self.process is None: + # 创建用于显示的批处理文件 + with open("message_queue_window.bat", "w", encoding="utf-8") as f: + f.write('@echo off\n') + f.write('chcp 65001\n') # 设置UTF-8编码 + f.write('title Message Queue Visualizer\n') + f.write('echo Waiting for message queue updates...\n') + f.write(':loop\n') + f.write('if exist "queue_update.txt" (\n') + f.write(' type "queue_update.txt" > "message_queue_content.txt"\n') + f.write(' del "queue_update.txt"\n') + f.write(' cls\n') + f.write(' type "message_queue_content.txt"\n') + f.write(')\n') + f.write('timeout /t 1 /nobreak >nul\n') + f.write('goto loop\n') + + # 清空内容文件 + with open(self.content_file, "w", encoding="utf-8") as f: + f.write("") + + # 启动新窗口 + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + self.process = subprocess.Popen( + ['cmd', '/c', 'start', 'message_queue_window.bat'], + shell=True, + startupinfo=startupinfo + ) + self.is_running = True + + # 启动处理线程 + threading.Thread(target=self._process_messages, daemon=True).start() + + def _process_messages(self): + while self.is_running: + try: + # 获取新消息 + text = self.message_queue.get(timeout=1) + # 写入更新文件 + with open("queue_update.txt", "w", encoding="utf-8") as f: + f.write(text) + except queue.Empty: + continue + except Exception as e: + print(f"处理队列可视化内容时出错: {e}") + + def update_content(self, 
send_temp_container): + """更新显示内容""" + if not self.is_running: + return + + current_time = time.strftime("%Y-%m-%d %H:%M:%S") + display_text = f"Message Queue Status - {current_time}\n" + display_text += "=" * 50 + "\n\n" + + # 遍历所有群组的队列 + for group_id, queue in send_temp_container.temp_queues.items(): + display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n" + display_text += f"消息队列长度: {len(queue.messages)}\n" + display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n" + display_text += "\n消息队列内容:\n" + + # 显示队列中的消息 + if not queue.messages: + display_text += " [空队列]\n" + else: + for i, msg in enumerate(queue.messages): + msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time)) + display_text += f"\n--- 消息 {i+1} ---\n" + + if isinstance(msg, Message_Thinking): + display_text += f"类型: \033[1;33m思考中消息\033[0m\n" + display_text += f"时间: {msg_time}\n" + display_text += f"消息ID: {msg.message_id}\n" + display_text += f"群组: {msg.group_id}\n" + display_text += f"用户: {msg.user_nickname}({msg.user_id})\n" + display_text += f"内容: {msg.thinking_text}\n" + display_text += f"思考时间: {msg.thinking_time}秒\n" + else: + display_text += f"类型: 普通消息\n" + display_text += f"时间: {msg_time}\n" + display_text += f"消息ID: {msg.message_id}\n" + display_text += f"群组: {msg.group_id}\n" + display_text += f"用户: {msg.user_nickname}({msg.user_id})\n" + if hasattr(msg, 'is_emoji') and msg.is_emoji: + display_text += f"内容: [表情包消息]\n" + else: + # 显示原始消息和处理后的消息 + display_text += f"原始内容: {msg.raw_message[:50]}...\n" + display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n" + + if msg.reply_message: + display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n" + + display_text += f"\n{'-' * 50}\n" + + # 添加统计信息 + display_text += "\n总体统计:\n" + display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n" + total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values()) + display_text += f"总消息数: {total_messages}\n" + 
thinking_messages = sum( + sum(1 for msg in q.messages if isinstance(msg, Message_Thinking)) + for q in send_temp_container.temp_queues.values() + ) + display_text += f"思考中消息数: {thinking_messages}\n" + + self.message_queue.put(display_text) + + def stop(self): + self.is_running = False + if self.process: + self.process.terminate() + self.process = None + # 清理文件 + for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]: + if os.path.exists(file): + os.remove(file) + +# 创建全局单例 +message_visualizer = MessageVisualizer() diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py new file mode 100644 index 000000000..1c6ecfc8b --- /dev/null +++ b/src/plugins/chat/prompt_builder.py @@ -0,0 +1,193 @@ +import time +import random +from dotenv import load_dotenv +from ..schedule.schedule_generator import bot_schedule +import os +from .utils import get_embedding, combine_messages, get_recent_group_messages +from ...common.database import Database + +# 获取当前文件的绝对路径 +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) +load_dotenv(os.path.join(root_dir, '.env')) + + +class PromptBuilder: + def __init__(self): + self.prompt_built = '' + self.activate_messages = '' + self.db = Database.get_instance() + + def _build_prompt(self, + message_txt: str, + sender_name: str = "某人", + relationship_value: float = 0.0, + group_id: int = None) -> str: + """构建prompt + + Args: + message_txt: 消息文本 + sender_name: 发送者昵称 + relationship_value: 关系值 + group_id: 群组ID + + Returns: + str: 构建好的prompt + """ + #先禁用关系 + if 0 > 30: + relation_prompt = "关系特别特别好,你很喜欢喜欢他" + relation_prompt_2 = "热情发言或者回复" + elif 0 <-20: + relation_prompt = "关系很差,你很讨厌他" + relation_prompt_2 = "骂他" + else: + relation_prompt = "关系一般" + relation_prompt_2 = "发言或者回复" + + #开始构建prompt + + #日程构建 + current_date = time.strftime("%Y-%m-%d", time.localtime()) + current_time = time.strftime("%H:%M:%S", 
def get_prompt_info(self, message: str):
    """Collect knowledge-base entries related to a message.

    Messages longer than 10 characters are cut into consecutive
    10-character segments and each segment is looked up separately;
    shorter messages are looked up as a whole.

    Args:
        message: The incoming message text.

    Returns:
        str: The matched knowledge entries, one lookup result per line and
        without repeats, or '' when nothing matched.
    """
    if len(message) > 10:
        segments = [message[i:i + 10] for i in range(0, len(message), 10)]
    else:
        segments = [message]

    found = []
    for segment in segments:
        embedding = get_embedding(segment)
        # get_info_from_db returns '' for a missing/empty embedding.
        info = self.get_info_from_db(embedding)
        # Several segments can hit the same entry; keep each result once.
        if info and info not in found:
            found.append(info)

    # The original built this text but never returned it, so callers always
    # received None and the knowledge section of the prompt stayed empty.
    return '\n'.join(found)
class Relationship:
    """In-memory record of what the bot knows about one user.

    The attributes can be supplied either as a ``data`` dict (e.g. a stored
    document) or as keyword arguments.
    """
    user_id: int = None
    # impression: Impression = None
    # group_id: int = None
    # group_name: str = None
    gender: str = None
    age: int = None
    nickname: str = None
    relationship_value: float = None
    saved = False

    def __init__(self, user_id: int, data=None, **kwargs):
        """Build a relationship record.

        Args:
            user_id: ID of the user. Used unless ``data`` carries its own
                non-None ``user_id``, which takes precedence as before.
            data: Optional dict holding the attribute values.
            **kwargs: Attribute values, used when ``data`` is not a dict.
        """
        source = data if isinstance(data, dict) else kwargs

        # The original read the ID only from data/kwargs, so records created
        # as Relationship(user_id, nickname=...) ended up with user_id=None.
        embedded_id = source.get('user_id')
        self.user_id = user_id if embedded_id is None else embedded_id

        self.gender = source.get('gender')
        self.age = source.get('age')
        self.nickname = source.get('nickname')
        self.relationship_value = source.get('relationship_value', 0.0)
        self.saved = source.get('saved', False)
+ + async def update_relationship(self, user_id: int, data=None, **kwargs): + # 检查是否在内存中已存在 + relationship = self.relationships.get(user_id) + if relationship: + # 如果存在,更新现有对象 + if isinstance(data, dict): + for key, value in data.items(): + if hasattr(relationship, key) and value is not None: + setattr(relationship, key, value) + else: + for key, value in kwargs.items(): + if hasattr(relationship, key) and value is not None: + setattr(relationship, key, value) + else: + # 如果不存在,创建新对象 + relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs) + self.relationships[user_id] = relationship + + # 保存到数据库 + await self.storage_relationship(relationship) + relationship.saved = True + + return relationship + + async def update_relationship_value(self, user_id: int, **kwargs): + # 检查是否在内存中已存在 + relationship = self.relationships.get(user_id) + if relationship: + for key, value in kwargs.items(): + if key == 'relationship_value': + relationship.relationship_value += value + await self.storage_relationship(relationship) + relationship.saved = True + return relationship + else: + print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新") + return None + + + def get_relationship(self, user_id: int) -> Optional[Relationship]: + """获取用户关系对象""" + if user_id in self.relationships: + return self.relationships[user_id] + else: + return 0 + + async def load_relationship(self, data: dict) -> Relationship: + """从数据库加载或创建新的关系对象""" + rela = Relationship(user_id=data['user_id'], data=data) + rela.saved = True + return rela + + async def _start_relationship_manager(self): + """每5分钟自动保存一次关系数据""" + db = Database.get_instance() + # 获取所有关系记录 + all_relationships = db.db.relationships.find({}) + # 依次加载每条记录 + for data in all_relationships: + user_id = data['user_id'] + relationship = await self.load_relationship(data) + self.relationships[user_id] = relationship + print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录") + + while True: + 
print(f"\033[1;32m[关系管理]\033[0m 正在自动保存关系") + await asyncio.sleep(300) # 等待300秒(5分钟) + await self._save_all_relationships() + + async def _save_all_relationships(self): + """将所有关系数据保存到数据库""" + # 保存所有关系数据 + for relationship in self.relationships: + if not relationship.saved: + relationship.saved = True + await self.storage_relationship(relationship) + + async def storage_relationship(self,relationship: Relationship): + """ + 将关系记录存储到数据库中 + """ + user_id = relationship.user_id + nickname = relationship.nickname + relationship_value = relationship.relationship_value + gender = relationship.gender + age = relationship.age + saved = relationship.saved + + db = Database.get_instance() + db.db.relationships.update_one( + {'user_id': user_id}, + {'$set': { + 'nickname': nickname, + 'relationship_value': relationship_value, + 'gender': gender, + 'age': age, + 'saved': saved + }}, + upsert=True + ) + + @staticmethod + async def get_user_nickname(bot: Bot, user_id: int, group_id: int = None) -> Tuple[str, Optional[str]]: + """ + 通过QQ API获取用户昵称 + """ + + # 获取QQ昵称 + stranger_info = await bot.get_stranger_info(user_id=user_id) + qq_nickname = stranger_info['nickname'] + + # 如果提供了群号,获取群昵称 + if group_id: + try: + member_info = await bot.get_group_member_info( + group_id=group_id, + user_id=user_id, + no_cache=True + ) + group_nickname = member_info['card'] or None + return qq_nickname, group_nickname + except: + return qq_nickname, None + + return qq_nickname, None + + def print_all_relationships(self): + """打印内存中所有的关系记录""" + print("\n\033[1;32m[关系管理]\033[0m 当前内存中的所有关系:") + print("=" * 50) + + if not self.relationships: + print("暂无关系记录") + return + + for user_id, relationship in self.relationships.items(): + print(f"用户ID: {user_id}") + print(f"昵称: {relationship.nickname}") + print(f"好感度: {relationship.relationship_value}") + print("-" * 30) + + print("=" * 50) + + + + + + +relationship_manager = RelationshipManager() \ No newline at end of file diff --git 
class MessageStorage:
    """Writes chat messages into the ``messages`` collection."""

    def __init__(self):
        self.db = Database.get_instance()

    async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
        """Store one message in the database.

        Emoji messages are stored with a fixed placeholder in place of their
        processed text. Any failure is printed and swallowed.

        Args:
            message: The message to store.
            topic: Optional topic label saved alongside the message.
        """
        try:
            emoji_message = message.is_emoji
            record = {
                "group_id": message.group_id,
                "user_id": message.user_id,
                "message_id": message.message_id,
                "raw_message": message.raw_message,
                "plain_text": message.plain_text,
                # Placeholder for emoji messages, real text otherwise.
                "processed_plain_text": '[表情包]' if emoji_message else message.processed_plain_text,
                "time": message.time,
                "user_nickname": message.user_nickname,
                "group_name": message.group_name,
                "topic": topic,
            }
            self.db.db.messages.insert_one(record)
        except Exception as e:
            print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
def parse_topic(self, topic: str) -> List[str]:
    """Split a topic string into a list of individual topics.

    Topics may be separated by an ASCII comma, a full-width comma or the
    Chinese enumeration comma; surrounding whitespace and empty pieces are
    dropped.

    Args:
        topic: The raw topic text, or None/'' when there is none.

    Returns:
        The list of topics; empty when ``topic`` is empty or is exactly
        the "no topic" marker.
    """
    if not topic or topic == "无主题":
        return []
    # Fold the full-width separators into the ASCII comma before splitting;
    # otherwise "a,b" would come back as one fused topic.
    normalized = topic.replace(",", ",").replace("、", ",")
    return [piece.strip() for piece in normalized.split(",") if piece.strip()]
def combine_messages(messages: List[Message]) -> str:
    """Render a list of messages as a chat transcript.

    Each message becomes one line of the form ``[MM-DD HH:MM:SS] name: text``
    (local time), terminated by a newline.

    Args:
        messages: The messages to render, in output order.

    Returns:
        str: The transcript; '' for an empty list.
    """
    rendered = []
    for entry in messages:
        stamp = time.strftime("%m-%d %H:%M:%S", time.localtime(entry.time))
        # Fall back to a generic label when the sender has no nickname.
        speaker = entry.user_nickname or f"用户{entry.user_id}"
        # Prefer the processed text, fall back to the plain text.
        body = entry.processed_plain_text or entry.plain_text
        rendered.append(f"[{stamp}] {speaker}: {body}\n")
    return "".join(rendered)
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
    """Fetch the most recent stored messages of a group.

    Args:
        db: Database instance; its ``db.messages`` collection is queried.
        group_id: ID of the group.
        limit: Maximum number of messages to fetch (default 12).

    Returns:
        list: Message objects in ascending time order; empty when the group
        has no stored messages.
    """
    # Newest first, so that limit() keeps the most recent entries.
    # The storage layer writes the processed text under "processed_plain_text";
    # "processed_text" is still requested as a fallback in case documents with
    # that key exist.
    recent_messages = list(db.db.messages.find(
        {"group_id": group_id},
        {
            "time": 1,
            "user_id": 1,
            "user_nickname": 1,
            "message_id": 1,
            "raw_message": 1,
            "processed_plain_text": 1,
            "processed_text": 1
        }
    ).sort("time", -1).limit(limit))

    if not recent_messages:
        return []

    # Convert the stored documents back into Message objects.
    from .message import Message
    message_objects = []
    for msg_data in recent_messages:
        processed = msg_data.get("processed_plain_text") or msg_data.get("processed_text", "")
        message_objects.append(Message(
            time=msg_data["time"],
            user_id=msg_data["user_id"],
            user_nickname=msg_data.get("user_nickname", ""),
            message_id=msg_data["message_id"],
            raw_message=msg_data["raw_message"],
            processed_plain_text=processed,
            group_id=group_id
        ))

    # The query returned newest-first; callers expect chronological order.
    message_objects.reverse()
    return message_objects
self.group_reply_willing = {} # 存储每个群的回复意愿 + self._decay_task = None + self._started = False + + async def _decay_reply_willing(self): + """定期衰减回复意愿""" + while True: + await asyncio.sleep(3) + for group_id in self.group_reply_willing: + # 每分钟衰减10%的回复意愿 + self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6) + + def get_willing(self, group_id: int) -> float: + """获取指定群组的回复意愿""" + return self.group_reply_willing.get(group_id, 0) + + def set_willing(self, group_id: int, willing: float): + """设置指定群组的回复意愿""" + self.group_reply_willing[group_id] = willing + + def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float: + """改变指定群组的回复意愿并返回回复概率""" + current_willing = self.group_reply_willing.get(group_id, 0) + + if topic and current_willing < 1: + current_willing += 0.6 + elif topic: + current_willing += 0.05 + + if is_mentioned_bot and current_willing < 1.0: + current_willing += 1 + elif is_mentioned_bot: + current_willing += 0.05 + + if is_emoji: + current_willing *= 0.2 + + self.group_reply_willing[group_id] = min(current_willing, 3.0) + + reply_probability = (current_willing - 0.5) * 2 + if group_id not in config.talk_allowed_groups: + current_willing = 0 + reply_probability = 0 + + if group_id in config.talk_frequency_down_groups: + reply_probability = reply_probability / 2 + + if is_mentioned_bot and user_id == int(1026294844): + reply_probability = 1 + + return reply_probability + + def change_reply_willing_sent(self, group_id: int): + """发送消息后降低群组的回复意愿""" + current_willing = self.group_reply_willing.get(group_id, 0) + self.group_reply_willing[group_id] = max(0, current_willing - 1.8) + + def change_reply_willing_after_sent(self, group_id: int): + """发送消息后提高群组的回复意愿""" + current_willing = self.group_reply_willing.get(group_id, 0) + # 如果当前意愿小于1,增加0.3的意愿值 + if current_willing < 1: + self.group_reply_willing[group_id] = min(1, current_willing + 
0.8) + + async def ensure_started(self): + """确保衰减任务已启动""" + if not self._started: + if self._decay_task is None: + self._decay_task = asyncio.create_task(self._decay_reply_willing()) + self._started = True + +# 创建全局实例 +willing_manager = WillingManager() \ No newline at end of file diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py new file mode 100644 index 000000000..72d7dd523 --- /dev/null +++ b/src/plugins/schedule/schedule_generator.py @@ -0,0 +1,156 @@ +import datetime +from typing import List, Dict +from .schedule_llm_module import LLMModel +from ...common.database import Database # 使用正确的导入语法 + + +# import sys +# sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径 +# from src.plugins.schedule.schedule_llm_module import LLMModel +# from src.common.database import Database # 使用正确的导入语法 + +Database.initialize( + "127.0.0.1", + 27017, + "MegBot" + ) + +class ScheduleGenerator: + def __init__(self): + self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3") + self.db = Database.get_instance() + + today = datetime.datetime.now() + tomorrow = datetime.datetime.now() + datetime.timedelta(days=1) + yesterday = datetime.datetime.now() - datetime.timedelta(days=1) + + self.today_schedule_text, self.today_schedule = self.generate_daily_schedule(target_date=today) + + self.tomorrow_schedule_text, self.tomorrow_schedule = self.generate_daily_schedule(target_date=tomorrow,read_only=True) + self.yesterday_schedule_text, self.yesterday_schedule = self.generate_daily_schedule(target_date=yesterday,read_only=True) + + def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]: + if target_date is None: + target_date = datetime.datetime.now() + + date_str = target_date.strftime("%Y-%m-%d") + weekday = target_date.strftime("%A") + + + schedule_text = str + + existing_schedule = self.db.db.schedule.find_one({"date": date_str}) + if existing_schedule: + 
print(f"{date_str}的日程已存在:") + schedule_text = existing_schedule["schedule"] + # print(self.schedule_text) + + elif read_only == False: + print(f"{date_str}的日程不存在,准备生成新的日程。") + prompt = f"""我是麦麦,一个地质学大二女大学生,喜欢刷qq,贴吧,知乎和小红书,请为我生成{date_str}({weekday})的日程安排,包括: + 1. 早上的学习和工作安排 + 2. 下午的活动和任务 + 3. 晚上的计划和休息时间 + 请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。""" + + schedule_text, _ = self.llm_scheduler.generate_response(prompt) + # print(self.schedule_text) + self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) + else: + print(f"{date_str}的日程不存在。") + schedule_text = "忘了" + + return schedule_text,None + + schedule_form = self._parse_schedule(schedule_text) + return schedule_text,schedule_form + + def _parse_schedule(self, schedule_text: str) -> Dict[str, str]: + """解析日程文本,转换为时间和活动的字典""" + schedule_dict = {} + # 按行分割日程文本 + lines = schedule_text.strip().split('\n') + for line in lines: + # print(line) + if ',' in line: + # 假设格式为 "时间: 活动" + time_str, activity = line.split(',', 1) + # print(time_str) + # print(activity) + schedule_dict[time_str.strip()] = activity.strip() + return schedule_dict + + def _parse_time(self, time_str: str) -> str: + """解析时间字符串,转换为时间""" + return datetime.datetime.strptime(time_str, "%H:%M") + + def get_current_task(self) -> str: + """获取当前时间应该进行的任务""" + current_time = datetime.datetime.now().strftime("%H:%M") + + # 找到最接近当前时间的任务 + closest_time = None + min_diff = float('inf') + + # 检查今天的日程 + for time_str in self.today_schedule.keys(): + diff = abs(self._time_diff(current_time, time_str)) + if closest_time is None or diff < min_diff: + closest_time = time_str + min_diff = diff + + # 检查昨天的日程中的晚间任务 + if self.yesterday_schedule: + for time_str in self.yesterday_schedule.keys(): + if time_str >= "20:00": # 只考虑晚上8点之后的任务 + # 计算与昨天这个时间点的差异(需要加24小时) + diff = abs(self._time_diff(current_time, time_str)) + if diff < min_diff: + closest_time = time_str + min_diff = diff + return closest_time, 
self.yesterday_schedule[closest_time] + + if closest_time: + return closest_time, self.today_schedule[closest_time] + return "摸鱼" + + def _time_diff(self, time1: str, time2: str) -> int: + """计算两个时间字符串之间的分钟差""" + t1 = datetime.datetime.strptime(time1, "%H:%M") + t2 = datetime.datetime.strptime(time2, "%H:%M") + diff = int((t2 - t1).total_seconds() / 60) + # 考虑时间的循环性 + if diff < -720: + diff += 1440 # 加一天的分钟 + elif diff > 720: + diff -= 1440 # 减一天的分钟 + # print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟") + return diff + + def print_schedule(self): + """打印完整的日程安排""" + + print("\n=== 今日日程安排 ===") + for time_str, activity in self.today_schedule.items(): + print(f"时间[{time_str}]: 活动[{activity}]") + print("==================\n") + +# def main(): +# # 使用示例 +# scheduler = ScheduleGenerator() +# # new_schedule = scheduler.generate_daily_schedule() +# scheduler.print_schedule() +# print("\n当前任务:") +# print(scheduler.get_current_task()) + +# print("昨天日程:") +# print(scheduler.yesterday_schedule) +# print("今天日程:") +# print(scheduler.today_schedule) +# print("明天日程:") +# print(scheduler.tomorrow_schedule) + +# if __name__ == "__main__": +# main() + +bot_schedule = ScheduleGenerator() \ No newline at end of file diff --git a/src/plugins/schedule/schedule_llm_module.py b/src/plugins/schedule/schedule_llm_module.py new file mode 100644 index 000000000..0f1e71f6c --- /dev/null +++ b/src/plugins/schedule/schedule_llm_module.py @@ -0,0 +1,55 @@ +import os +import requests +from dotenv import load_dotenv +from typing import Tuple, Union + +# 加载环境变量 +load_dotenv() + +class LLMModel: + # def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs): + def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1", **kwargs): + self.model_name = model_name + self.params = kwargs + self.api_key = os.getenv("SILICONFLOW_KEY") + self.base_url = os.getenv("SILICONFLOW_BASE_URL") + + def generate_response(self, prompt: str) -> Tuple[str, str]: + """根据输入的提示生成模型的响应""" + headers 
= { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 构建请求体 + data = { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.9, + **self.params + } + + # 发送请求到完整的chat/completions端点 + api_url = f"{self.base_url.rstrip('/')}/chat/completions" + + try: + response = requests.post(api_url, headers=headers, json=data) + response.raise_for_status() # 检查响应状态 + + result = response.json() + if "choices" in result and len(result["choices"]) > 0: + content = result["choices"][0]["message"]["content"] + reasoning_content = result["choices"][0]["message"].get("reasoning_content", "") + return content, reasoning_content # 返回内容和推理内容 + return "没有返回结果", "" # 返回两个值 + + except requests.exceptions.RequestException as e: + return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串 + +# 示例用法 +if __name__ == "__main__": + model = LLMModel() # 默认使用 DeepSeek-V3 模型 + prompt = "你好,你喜欢我吗?" + result, reasoning = model.generate_response(prompt) + print("回复内容:", result) + print("推理内容:", reasoning) \ No newline at end of file