v0.1
能跑但是没写部署教程,主题和记忆识别也没写完
This commit is contained in:
1
src/common/__init__.py
Normal file
1
src/common/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# 这个文件可以为空,但必须存在
|
||||
21
src/common/database.py
Normal file
21
src/common/database.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pymongo import MongoClient
|
||||
from typing import Optional
|
||||
|
||||
class Database:
|
||||
_instance: Optional["Database"] = None
|
||||
|
||||
def __init__(self, host: str, port: int, db_name: str):
|
||||
self.client = MongoClient(host, port)
|
||||
self.db = self.client[db_name]
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, host: str, port: int, db_name: str) -> "Database":
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(host, port, db_name)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "Database":
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Database not initialized")
|
||||
return cls._instance
|
||||
201
src/plugins/chat/.stream.py
Normal file
201
src/plugins/chat/.stream.py
Normal file
@@ -0,0 +1,201 @@
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
import time
|
||||
import threading
|
||||
import asyncio
|
||||
from .message import Message
|
||||
from .storage import MessageStorage
|
||||
from .topic_identifier import TopicIdentifier
|
||||
from ...common.database import Database
|
||||
import random
|
||||
|
||||
@dataclass
|
||||
class Topic:
|
||||
id: str
|
||||
name: str
|
||||
messages: List[Message]
|
||||
created_time: float
|
||||
last_active_time: float
|
||||
message_count: int
|
||||
is_active: bool = True
|
||||
|
||||
class MessageStream:
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
self.active_topics: Dict[int, List[Topic]] = {} # group_id -> topics
|
||||
self.topic_identifier = TopicIdentifier()
|
||||
self.db = Database.get_instance()
|
||||
self.topic_lock = threading.Lock()
|
||||
|
||||
async def start(self):
|
||||
"""异步初始化"""
|
||||
asyncio.create_task(self._monitor_topics())
|
||||
|
||||
async def _monitor_topics(self):
|
||||
"""定时监控主题状态"""
|
||||
while True:
|
||||
await asyncio.sleep(30)
|
||||
self._print_active_topics()
|
||||
self._check_inactive_topics()
|
||||
self._remove_small_topic()
|
||||
|
||||
def _print_active_topics(self):
|
||||
"""打印当前活跃主题"""
|
||||
print("\n" + "="*50)
|
||||
print("\033[1;36m【当前活跃主题】\033[0m") # 青色
|
||||
for group_id, topics in self.active_topics.items():
|
||||
active_topics = [t for t in topics if t.is_active]
|
||||
if active_topics:
|
||||
print(f"\n\033[1;33m群组 {group_id}:\033[0m") # 黄色
|
||||
for topic in active_topics:
|
||||
print(f"\033[1;32m- {topic.name}\033[0m (消息数: {topic.message_count})") # 绿色
|
||||
|
||||
def _check_inactive_topics(self):
|
||||
"""检查并处理不活跃主题"""
|
||||
current_time = time.time()
|
||||
INACTIVE_TIME = 600 # 60秒内没有新增内容
|
||||
# MAX_MESSAGES_WITHOUT_TOPIC = 5 # 最新5条消息都不是这个主题就归档
|
||||
|
||||
with self.topic_lock:
|
||||
for group_id, topics in self.active_topics.items():
|
||||
|
||||
for topic in topics:
|
||||
if not topic.is_active:
|
||||
continue
|
||||
|
||||
# 检查是否超过不活跃时间
|
||||
time_inactive = current_time - topic.last_active_time
|
||||
if time_inactive > INACTIVE_TIME:
|
||||
# print(f"\033[1;33m[主题超时]\033[0m {topic.name} 已有 {int(time_inactive)} 秒未更新")
|
||||
self._archive_topic(group_id, topic)
|
||||
topic.is_active = False
|
||||
continue
|
||||
|
||||
|
||||
def _archive_topic(self, group_id: int, topic: Topic):
|
||||
"""将主题存档到数据库"""
|
||||
# 查找是否有同名主题
|
||||
existing_topic = self.db.db.archived_topics.find_one({
|
||||
"name": topic.name
|
||||
})
|
||||
|
||||
if existing_topic:
|
||||
# 合并消息列表并去重
|
||||
existing_messages = existing_topic.get("messages", [])
|
||||
new_messages = [
|
||||
{
|
||||
"user_id": msg.user_id,
|
||||
"plain_text": msg.plain_text,
|
||||
"time": msg.time
|
||||
} for msg in topic.messages
|
||||
]
|
||||
|
||||
# 使用集合去重
|
||||
seen_texts = set()
|
||||
unique_messages = []
|
||||
|
||||
# 先处理现有消息
|
||||
for msg in existing_messages:
|
||||
if msg["plain_text"] not in seen_texts:
|
||||
seen_texts.add(msg["plain_text"])
|
||||
unique_messages.append(msg)
|
||||
|
||||
# 再处理新消息
|
||||
for msg in new_messages:
|
||||
if msg["plain_text"] not in seen_texts:
|
||||
seen_texts.add(msg["plain_text"])
|
||||
unique_messages.append(msg)
|
||||
|
||||
# 更新主题信息
|
||||
self.db.db.archived_topics.update_one(
|
||||
{"_id": existing_topic["_id"]},
|
||||
{
|
||||
"$set": {
|
||||
"messages": unique_messages,
|
||||
"message_count": len(unique_messages),
|
||||
"last_active_time": max(existing_topic["last_active_time"], topic.last_active_time),
|
||||
"last_merged_time": time.time()
|
||||
}
|
||||
}
|
||||
)
|
||||
print(f"\033[1;33m[主题合并]\033[0m 主题 {topic.name} 已合并,总消息数: {len(unique_messages)}")
|
||||
|
||||
else:
|
||||
# 存储新主题
|
||||
self.db.db.archived_topics.insert_one({
|
||||
"topic_id": topic.id,
|
||||
"name": topic.name,
|
||||
"messages": [
|
||||
{
|
||||
"user_id": msg.user_id,
|
||||
"plain_text": msg.plain_text,
|
||||
"time": msg.time
|
||||
} for msg in topic.messages
|
||||
],
|
||||
"created_time": topic.created_time,
|
||||
"last_active_time": topic.last_active_time,
|
||||
"message_count": topic.message_count
|
||||
})
|
||||
print(f"\033[1;32m[主题存档]\033[0m {topic.name} (群组: {group_id})")
|
||||
|
||||
async def process_message(self, message: Message,topic:List[str]):
|
||||
"""处理新消息,返回识别出的主题列表"""
|
||||
# 存储消息(包含主题)
|
||||
await self.storage.store_message(message, topic)
|
||||
self._update_topics(message.group_id, topic, message)
|
||||
|
||||
def _update_topics(self, group_id: int, topic_names: List[str], message: Message) -> None:
|
||||
"""更新群组主题"""
|
||||
current_time = time.time()
|
||||
|
||||
# 确保群组存在
|
||||
if group_id not in self.active_topics:
|
||||
self.active_topics[group_id] = []
|
||||
|
||||
# 查找现有主题
|
||||
for topic_name in topic_names:
|
||||
for topic in self.active_topics[group_id]:
|
||||
if topic.name == topic_name:
|
||||
topic.messages.append(message)
|
||||
topic.last_active_time = current_time
|
||||
topic.message_count += 1
|
||||
print(f"\033[1;35m[更新主题]\033[0m {topic_name}") # 绿色
|
||||
break
|
||||
else:
|
||||
# 创建新主题
|
||||
new_topic = Topic(
|
||||
id=f"{group_id}_{int(current_time)}",
|
||||
name=topic_name,
|
||||
messages=[message],
|
||||
created_time=current_time,
|
||||
last_active_time=current_time,
|
||||
message_count=1
|
||||
)
|
||||
self.active_topics[group_id].append(new_topic)
|
||||
|
||||
self._check_inactive_topics()
|
||||
|
||||
def _remove_small_topic(self):
|
||||
"""随机移除一个12小时内没有新增内容的小主题"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
inactive_time = 12 * 3600 # 24小时
|
||||
|
||||
# 获取所有符合条件的主题
|
||||
topics = list(self.db.db.archived_topics.find({
|
||||
"message_count": {"$lt": 3}, # 消息数小于2
|
||||
"last_active_time": {"$lt": current_time - inactive_time}
|
||||
}))
|
||||
|
||||
if not topics:
|
||||
return
|
||||
|
||||
# 随机选择一个主题删除
|
||||
topic_to_remove = random.choice(topics)
|
||||
inactive_hours = (current_time - topic_to_remove.get("last_active_time", 0)) / 3600
|
||||
|
||||
self.db.db.archived_topics.delete_one({"_id": topic_to_remove["_id"]})
|
||||
print(f"\033[1;31m[主题清理]\033[0m 已移除小主题: {topic_to_remove['name']} "
|
||||
f"不活跃时间: {int(inactive_hours)}小时)")
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 移除小主题失败: {str(e)}")
|
||||
82
src/plugins/chat/__init__.py
Normal file
82
src/plugins/chat/__init__.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from nonebot import on_message, on_command, require, get_driver
|
||||
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
|
||||
from nonebot.typing import T_State
|
||||
from ...common.database import Database
|
||||
from .config import global_config
|
||||
import os
|
||||
import asyncio
|
||||
import random
|
||||
from .relationship_manager import relationship_manager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .willing_manager import willing_manager
|
||||
|
||||
|
||||
|
||||
# 获取驱动器
|
||||
driver = get_driver()
|
||||
|
||||
Database.initialize(
|
||||
global_config.MONGODB_HOST,
|
||||
global_config.MONGODB_PORT,
|
||||
global_config.DATABASE_NAME
|
||||
)
|
||||
|
||||
print("\033[1;32m[初始化配置和数据库完成]\033[0m")
|
||||
|
||||
|
||||
# 导入其他模块
|
||||
from .bot import ChatBot
|
||||
from .emoji_manager import emoji_manager
|
||||
from .message_send_control import message_sender
|
||||
from .relationship_manager import relationship_manager
|
||||
|
||||
# 初始化表情管理器
|
||||
emoji_manager.initialize()
|
||||
|
||||
print("\033[1;32m正在唤醒麦麦......\033[0m")
|
||||
# 创建机器人实例
|
||||
chat_bot = ChatBot(global_config)
|
||||
|
||||
# 注册消息处理器
|
||||
group_msg = on_message()
|
||||
|
||||
# 创建定时任务
|
||||
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
||||
|
||||
# 启动后台任务
|
||||
@driver.on_startup
|
||||
async def start_background_tasks():
|
||||
"""启动后台任务"""
|
||||
# 只启动表情包管理任务
|
||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
||||
|
||||
bot_schedule.print_schedule()
|
||||
|
||||
@driver.on_bot_connect
|
||||
async def _(bot: Bot):
|
||||
"""Bot连接成功时的处理"""
|
||||
print("\033[1;38;5;208m-----------麦麦成功连接!-----------\033[0m")
|
||||
message_sender.set_bot(bot)
|
||||
asyncio.create_task(message_sender.start_processor(bot))
|
||||
await willing_manager.ensure_started()
|
||||
print("\033[1;38;5;208m-----------麦麦消息发送器已启动!-----------\033[0m")
|
||||
|
||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
||||
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
||||
# 启动消息发送控制任务
|
||||
|
||||
@driver.on_startup
|
||||
async def init_relationships():
|
||||
"""在 NoneBot2 启动时初始化关系管理器"""
|
||||
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
|
||||
asyncio.create_task(relationship_manager._start_relationship_manager())
|
||||
|
||||
@group_msg.handle()
|
||||
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
||||
await chat_bot.handle_message(event, bot)
|
||||
|
||||
@scheduler.scheduled_job("interval", seconds=300000, id="monitor_relationships")
|
||||
async def monitor_relationships():
|
||||
"""每15秒打印一次关系数据"""
|
||||
relationship_manager.print_all_relationships()
|
||||
|
||||
224
src/plugins/chat/bot.py
Normal file
224
src/plugins/chat/bot.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from nonebot.adapters.onebot.v11 import GroupMessageEvent, Message as EventMessage, Bot
|
||||
from .message import Message,MessageSet
|
||||
from .config import BotConfig, global_config
|
||||
from .storage import MessageStorage
|
||||
from .gpt_response import GPTResponseGenerator
|
||||
from .message_stream import MessageStream, MessageStreamContainer
|
||||
from .topic_identifier import topic_identifier
|
||||
from random import random
|
||||
from nonebot.log import logger
|
||||
from .group_info_manager import GroupInfoManager # 导入群信息管理器
|
||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||
import time
|
||||
import os
|
||||
from .cq_code import CQCode # 导入CQCode模块
|
||||
from .message_send_control import message_sender # 导入消息发送控制器
|
||||
from .message import Message_Thinking # 导入 Message_Thinking 类
|
||||
from .relationship_manager import relationship_manager
|
||||
from .prompt_builder import prompt_builder
|
||||
from .willing_manager import willing_manager # 导入意愿管理器
|
||||
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self, config: BotConfig):
|
||||
self.config = config
|
||||
self.storage = MessageStorage()
|
||||
self.gpt = GPTResponseGenerator(config)
|
||||
self.group_info_manager = GroupInfoManager() # 初始化群信息管理器
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
|
||||
self.emoji_chance = 0.2 # 发送表情包的基础概率
|
||||
self.message_streams = MessageStreamContainer()
|
||||
self.message_sender = message_sender
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
if not self._started:
|
||||
# 只保留必要的任务
|
||||
self._started = True
|
||||
|
||||
def is_mentioned_bot(self, message: Message) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = ['麦麦']
|
||||
for keyword in keywords:
|
||||
if keyword in message.processed_plain_text:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||
"""处理收到的群消息"""
|
||||
|
||||
if event.group_id not in self.config.talk_allowed_groups:
|
||||
return
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
# 打印原始消息内容
|
||||
'''
|
||||
print(f"\n\033[1;33m[消息详情]\033[0m")
|
||||
# print(f"- 原始消息: {str(event.raw_message)}")
|
||||
print(f"- post_type: {event.post_type}")
|
||||
print(f"- sub_type: {event.sub_type}")
|
||||
print(f"- user_id: {event.user_id}")
|
||||
print(f"- message_type: {event.message_type}")
|
||||
# print(f"- message_id: {event.message_id}")
|
||||
# print(f"- message: {event.message}")
|
||||
print(f"- original_message: {event.original_message}")
|
||||
print(f"- raw_message: {event.raw_message}")
|
||||
# print(f"- font: {event.font}")
|
||||
print(f"- sender: {event.sender}")
|
||||
# print(f"- to_me: {event.to_me}")
|
||||
|
||||
if event.reply:
|
||||
print(f"\n\033[1;33m[回复消息详情]\033[0m")
|
||||
# print(f"- message_id: {event.reply.message_id}")
|
||||
print(f"- message_type: {event.reply.message_type}")
|
||||
print(f"- sender: {event.reply.sender}")
|
||||
# print(f"- time: {event.reply.time}")
|
||||
print(f"- message: {event.reply.message}")
|
||||
print(f"- raw_message: {event.reply.raw_message}")
|
||||
# print(f"- original_message: {event.reply.original_message}")
|
||||
'''
|
||||
|
||||
# 获取群组信息,发送消息的用户信息,并对数据库内容做一次更新
|
||||
|
||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
await self.group_info_manager.update_group_info(
|
||||
group_id=event.group_id,
|
||||
group_name=group_info['group_name'],
|
||||
member_count=group_info['member_count']
|
||||
)
|
||||
|
||||
|
||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
|
||||
# print(f"\033[1;32m[关系管理]\033[0m 更新关系: {sender_info}")
|
||||
|
||||
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
|
||||
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
|
||||
print(f"\033[1;32m[关系管理]\033[0m 更新关系值: {relationship_manager.get_relationship(event.user_id).relationship_value}")
|
||||
|
||||
|
||||
|
||||
message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=event.user_id,
|
||||
message_id=event.message_id,
|
||||
raw_message=str(event.original_message),
|
||||
plain_text=event.get_plaintext(),
|
||||
reply_message=event.reply,
|
||||
)
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
|
||||
topic = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic}")
|
||||
|
||||
await self.storage.store_message(message, topic[0] if topic else None)
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
|
||||
print(f"\033[1;34m[调试]\033[0m 当前消息是否是表情包: {message.is_emoji}")
|
||||
|
||||
is_mentioned = self.is_mentioned_bot(message)
|
||||
reply_probability = willing_manager.change_reply_willing_received(
|
||||
event.group_id,
|
||||
topic[0] if topic else None,
|
||||
is_mentioned,
|
||||
self.config,
|
||||
event.user_id,
|
||||
message.is_emoji
|
||||
)
|
||||
current_willing = willing_manager.get_willing(event.group_id)
|
||||
|
||||
|
||||
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability:.1f}]\033[0m")
|
||||
response = ""
|
||||
if random() < reply_probability:
|
||||
|
||||
tinking_time_point = round(time.time(), 2)
|
||||
think_id = 'mt' + str(tinking_time_point)
|
||||
|
||||
thinking_message = Message_Thinking(message=message,message_id=think_id)
|
||||
|
||||
message_sender.send_temp_container.add_message(thinking_message)
|
||||
|
||||
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
||||
# 生成回复
|
||||
response, emotion = await self.gpt.generate_response(message)
|
||||
|
||||
# 如果生成了回复,发送并记录
|
||||
if response:
|
||||
message_set = MessageSet(event.group_id, self.config.BOT_QQ, think_id)
|
||||
if isinstance(response, list):
|
||||
# 将多条消息合并成一条
|
||||
for msg in response:
|
||||
# print(f"\033[1;34m[调试]\033[0m 载入消息消息: {msg}")
|
||||
# bot_response_time = round(time.time(), 2)
|
||||
timepoint = tinking_time_point-0.3
|
||||
bot_message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=self.config.BOT_QQ,
|
||||
message_id=think_id,
|
||||
message_based_id=event.message_id,
|
||||
raw_message=msg,
|
||||
plain_text=msg,
|
||||
processed_plain_text=msg,
|
||||
user_nickname="麦麦",
|
||||
group_name=message.group_name,
|
||||
time=timepoint
|
||||
)
|
||||
# print(f"\033[1;34m[调试]\033[0m 添加消息到消息组: {bot_message}")
|
||||
message_set.add_message(bot_message)
|
||||
# print(f"\033[1;34m[调试]\033[0m 输入消息组: {message_set}")
|
||||
message_sender.send_temp_container.update_thinking_message(message_set)
|
||||
else:
|
||||
# bot_response_time = round(time.time(), 2)
|
||||
bot_message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=self.config.BOT_QQ,
|
||||
message_id=think_id,
|
||||
message_based_id=event.message_id,
|
||||
raw_message=response,
|
||||
plain_text=response,
|
||||
processed_plain_text=response,
|
||||
user_nickname="麦麦",
|
||||
group_name=message.group_name,
|
||||
time=tinking_time_point
|
||||
)
|
||||
# print(f"\033[1;34m[调试]\033[0m 更新单条消息: {bot_message}")
|
||||
message_sender.send_temp_container.update_thinking_message(bot_message)
|
||||
|
||||
|
||||
bot_response_time = tinking_time_point
|
||||
if random() < self.config.emoji_chance:
|
||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
||||
if emoji_path:
|
||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||
|
||||
if random() < 0.5:
|
||||
bot_response_time = tinking_time_point - 1
|
||||
# else:
|
||||
# bot_response_time = bot_response_time + 1
|
||||
|
||||
bot_message = Message(
|
||||
group_id=event.group_id,
|
||||
user_id=self.config.BOT_QQ,
|
||||
message_id=0,
|
||||
raw_message=emoji_cq,
|
||||
plain_text=emoji_cq,
|
||||
processed_plain_text=emoji_cq,
|
||||
user_nickname="麦麦",
|
||||
group_name=message.group_name,
|
||||
time=bot_response_time,
|
||||
is_emoji=True
|
||||
)
|
||||
message_sender.send_temp_container.add_message(bot_message)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 如果收到新消息,提高回复意愿
|
||||
willing_manager.change_reply_willing_after_sent(event.group_id)
|
||||
73
src/plugins/chat/bot_config.toml
Normal file
73
src/plugins/chat/bot_config.toml
Normal file
@@ -0,0 +1,73 @@
|
||||
[database]
|
||||
host = "127.0.0.1"
|
||||
port = 27017
|
||||
name = "MegBot"
|
||||
|
||||
[bot]
|
||||
qq = 2814567326
|
||||
|
||||
[message]
|
||||
min_text_length = 2
|
||||
max_context_size = 15
|
||||
emoji_chance = 0.2
|
||||
|
||||
[emoji]
|
||||
check_interval = 120
|
||||
register_interval = 10
|
||||
|
||||
[response]
|
||||
model_r1_probability = 0.2
|
||||
|
||||
|
||||
[groups]
|
||||
read_allowed = [
|
||||
1030993430, #bot_test_group_1
|
||||
# 1015816696, #m43white
|
||||
739044565, #my_group
|
||||
192194125, #ms
|
||||
591693379, #bot_test_group_2
|
||||
179648561, #nkyy
|
||||
764408046, #daily_news
|
||||
435591861, #m43black
|
||||
851345375, #hjy群
|
||||
708847644, #rotate_cmy
|
||||
534940728, #bh_llh_HYY
|
||||
# 549292720, #mrfz
|
||||
# 231561425, #粉丝群
|
||||
975992476,
|
||||
1140700103,
|
||||
752426484,#nd1
|
||||
115843978,#nd2
|
||||
# 168718420 #bh
|
||||
]
|
||||
|
||||
talk_allowed = [
|
||||
1030993430, #bot_test_group_1
|
||||
# 1015816696, #m43white
|
||||
739044565, #my_group
|
||||
192194125, #ms
|
||||
591693379, #bot_test_group_2
|
||||
179648561, #nkyy
|
||||
764408046, #daily_news
|
||||
#435591861, #m43black
|
||||
851345375, #hjy群
|
||||
708847644, #rotate_cmy
|
||||
534940728, #bh_llh_HYY
|
||||
# 231561425, #粉丝群
|
||||
975992476,
|
||||
1140700103,
|
||||
# 168718420#bh
|
||||
# 752426484,#nd1
|
||||
# 115843978,#nd2
|
||||
]
|
||||
|
||||
talk_frequency_down = [
|
||||
549292720, #mrfz
|
||||
435591861, #m43black
|
||||
# 231561425,
|
||||
975992476,
|
||||
1140700103,
|
||||
534940728
|
||||
# 752426484,#nd1
|
||||
# 115843978,#nd2
|
||||
]
|
||||
109
src/plugins/chat/config.py
Normal file
109
src/plugins/chat/config.py
Normal file
@@ -0,0 +1,109 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, Optional
|
||||
import os
|
||||
from nonebot.log import logger, default_format
|
||||
import logging
|
||||
import configparser # 添加这行导入
|
||||
import tomli # 添加这行导入
|
||||
|
||||
# 禁用默认的日志输出
|
||||
# logger.remove()
|
||||
|
||||
# # 只禁用 INFO 级别的日志输出到控制台
|
||||
# logging.getLogger('nonebot').handlers.clear()
|
||||
# console_handler = logging.StreamHandler()
|
||||
# console_handler.setLevel(logging.WARNING) # 只输出 WARNING 及以上级别
|
||||
# logging.getLogger('nonebot').addHandler(console_handler)
|
||||
# logging.getLogger('nonebot').setLevel(logging.WARNING)
|
||||
|
||||
@dataclass
|
||||
class BotConfig:
|
||||
"""机器人配置类"""
|
||||
|
||||
# 基础配置
|
||||
MONGODB_HOST: str = "127.0.0.1"
|
||||
MONGODB_PORT: int = 27017
|
||||
DATABASE_NAME: str = "MegBot"
|
||||
|
||||
BOT_QQ: Optional[int] = None
|
||||
|
||||
# 消息处理相关配置
|
||||
MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度
|
||||
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
|
||||
emoji_chance: float = 0.2 # 发送表情包的基础概率
|
||||
|
||||
read_allowed_groups = set()
|
||||
talk_allowed_groups = set()
|
||||
talk_frequency_down_groups = set()
|
||||
|
||||
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
|
||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||
|
||||
MODEL_R1_PROBABILITY: float = 0.3 # R1模型概率
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path: str = "bot_config.toml") -> "BotConfig":
|
||||
"""从TOML配置文件加载配置"""
|
||||
config = cls()
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, "rb") as f:
|
||||
toml_dict = tomli.load(f)
|
||||
|
||||
# 数据库配置
|
||||
if "database" in toml_dict:
|
||||
db_config = toml_dict["database"]
|
||||
config.MONGODB_HOST = db_config.get("host", config.MONGODB_HOST)
|
||||
config.MONGODB_PORT = db_config.get("port", config.MONGODB_PORT)
|
||||
config.DATABASE_NAME = db_config.get("name", config.DATABASE_NAME)
|
||||
|
||||
if "emoji" in toml_dict:
|
||||
emoji_config = toml_dict["emoji"]
|
||||
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
|
||||
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
|
||||
|
||||
# 机器人基础配置
|
||||
if "bot" in toml_dict:
|
||||
bot_config = toml_dict["bot"]
|
||||
bot_qq = bot_config.get("qq")
|
||||
config.BOT_QQ = int(bot_qq)
|
||||
|
||||
|
||||
if "response" in toml_dict:
|
||||
response_config = toml_dict["response"]
|
||||
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
|
||||
|
||||
# 消息配置
|
||||
if "message" in toml_dict:
|
||||
msg_config = toml_dict["message"]
|
||||
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
|
||||
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
|
||||
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
|
||||
|
||||
# 群组配置
|
||||
if "groups" in toml_dict:
|
||||
groups_config = toml_dict["groups"]
|
||||
config.read_allowed_groups = set(groups_config.get("read_allowed", []))
|
||||
config.talk_allowed_groups = set(groups_config.get("talk_allowed", []))
|
||||
config.talk_frequency_down_groups = set(groups_config.get("talk_frequency_down", []))
|
||||
|
||||
print(f"\033[1;32m成功加载配置文件: {config_path}\033[0m")
|
||||
|
||||
return config
|
||||
|
||||
global_config = BotConfig.load_config("./src/plugins/chat/bot_config.toml")
|
||||
|
||||
from dotenv import load_dotenv
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
load_dotenv(os.path.join(root_dir, '.env'))
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""机器人配置类"""
|
||||
# 基础配置
|
||||
SILICONFLOW_API_KEY: str = None
|
||||
SILICONFLOW_BASE_URL: str = None
|
||||
|
||||
llm_config = LLMConfig()
|
||||
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
|
||||
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
|
||||
422
src/plugins/chat/cq_code.py
Normal file
422
src/plugins/chat/cq_code.py
Normal file
@@ -0,0 +1,422 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
import html
|
||||
import requests
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
from .image_utils import storage_compress_image, storage_emoji
|
||||
import os
|
||||
from random import random
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from .config import global_config, llm_config
|
||||
import time
|
||||
import asyncio
|
||||
@dataclass
|
||||
class CQCode:
|
||||
"""
|
||||
CQ码数据类,用于存储和处理CQ码
|
||||
|
||||
属性:
|
||||
type: CQ码类型(如'image', 'at', 'face'等)
|
||||
params: CQ码的参数字典
|
||||
raw_code: 原始CQ码字符串
|
||||
translated_plain_text: 经过处理(如AI翻译)后的文本表示
|
||||
"""
|
||||
type: str
|
||||
params: Dict[str, str]
|
||||
raw_code: str
|
||||
group_id: int
|
||||
user_id: int
|
||||
group_name: str = ""
|
||||
user_nickname: str = ""
|
||||
translated_plain_text: Optional[str] = None
|
||||
reply_message: Dict = None # 存储回复消息
|
||||
|
||||
@classmethod
|
||||
def from_cq_code(cls, cq_code: str, reply: Dict = None) -> 'CQCode':
|
||||
"""
|
||||
从CQ码字符串创建CQCode对象
|
||||
例如:[CQ:image,file=1.jpg,url=http://example.com/1.jpg]
|
||||
"""
|
||||
# 移除前后的[]
|
||||
content = cq_code[1:-1]
|
||||
# 分离类型和参数部分
|
||||
parts = content.split(',')
|
||||
if not parts:
|
||||
return cls('text', {'text': cq_code}, cq_code, group_id=0, user_id=0)
|
||||
|
||||
# 获取CQ类型
|
||||
cq_type = parts[0][3:] # 去掉'CQ:'
|
||||
|
||||
# 解析参数
|
||||
params = {}
|
||||
for part in parts[1:]:
|
||||
if '=' in part:
|
||||
key, value = part.split('=', 1)
|
||||
# 处理转义字符
|
||||
value = cls.unescape(value)
|
||||
params[key] = value
|
||||
|
||||
# 创建实例
|
||||
instance = cls(cq_type, params, cq_code, group_id=0, user_id=0, reply_message=reply)
|
||||
# 根据类型进行相应的翻译处理
|
||||
instance.translate()
|
||||
return instance
|
||||
|
||||
def translate(self):
|
||||
"""根据CQ码类型进行相应的翻译处理"""
|
||||
if self.type == 'text':
|
||||
self.translated_plain_text = self.params.get('text', '')
|
||||
elif self.type == 'image':
|
||||
self.translated_plain_text = self.translate_image()
|
||||
elif self.type == 'at':
|
||||
from .message import Message
|
||||
message_obj = Message(
|
||||
user_id=str(self.params.get('qq', ''))
|
||||
)
|
||||
self.translated_plain_text = f"@{message_obj.user_nickname}"
|
||||
elif self.type == 'reply':
|
||||
self.translated_plain_text = self.translate_reply()
|
||||
elif self.type == 'face':
|
||||
face_id = self.params.get('id', '')
|
||||
# self.translated_plain_text = f"[表情{face_id}]"
|
||||
self.translated_plain_text = f"[表情]"
|
||||
elif self.type == 'forward':
|
||||
self.translated_plain_text = self.translate_forward()
|
||||
else:
|
||||
self.translated_plain_text = f"[{self.type}]"
|
||||
|
||||
def translate_image(self) -> str:
|
||||
"""处理图片类型的CQ码,区分普通图片和表情包"""
|
||||
if 'url' not in self.params:
|
||||
return '[图片]'
|
||||
|
||||
# 获取子类型,默认为普通图片(0)
|
||||
sub_type = int(self.params.get('sub_type', '0'))
|
||||
is_emoji = (sub_type == 1)
|
||||
|
||||
# 添加请求头
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Accept': 'image/webp,image/apng,image/*,*/*;q=0.8',
|
||||
'Accept-Encoding': 'gzip, deflate, br',
|
||||
'Connection': 'keep-alive'
|
||||
}
|
||||
|
||||
# 处理URL编码问题
|
||||
url = html.unescape(self.params['url'])
|
||||
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
raise ValueError(f"无效的URL格式: {url}")
|
||||
|
||||
# 下载图片
|
||||
response = requests.get(url, headers=headers, timeout=10, verify=False)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 检查响应内容类型
|
||||
content_type = response.headers.get('content-type', '')
|
||||
if not content_type.startswith('image/'):
|
||||
raise ValueError(f"响应不是图片类型: {content_type}")
|
||||
|
||||
content = response.content
|
||||
|
||||
image_base64 = base64.b64encode(content).decode('utf-8')
|
||||
|
||||
# 根据子类型选择不同的处理方式
|
||||
if sub_type == 1: # 表情包
|
||||
return self.get_emoji_description(image_base64)
|
||||
elif sub_type == 0: # 普通图片
|
||||
if self.get_image_description_is_setu(image_base64) == "是":
|
||||
print(f"\033[1;34m[调试]\033[0m 哇!涩情图片")
|
||||
# 使用相对路径创建目录
|
||||
# data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "data", "setu")
|
||||
# os.makedirs(data_dir, exist_ok=True)
|
||||
# # 生成随机文件名
|
||||
# file_name = f"{int(time.time())}_{int(random() * 10000)}.jpg"
|
||||
# file_path = os.path.join(data_dir, file_name)
|
||||
# # 将base64解码并保存图片
|
||||
# image_data = base64.b64decode(image_base64)
|
||||
# with open(file_path, "wb") as f:
|
||||
# f.write(image_data)
|
||||
# print(f"\033[1;34m[调试]\033[0m 涩图已保存至: {file_path}")
|
||||
|
||||
return f"[一张涩情图片]"
|
||||
return self.get_image_description(image_base64)
|
||||
else: # 其他类型都按普通图片处理
|
||||
return '[图片]'
|
||||
else:
|
||||
raise ValueError(f"下载图片失败: HTTP状态码 {response.status_code}")
|
||||
|
||||
|
||||
def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""调用AI接口获取表情包描述"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "deepseek-ai/deepseek-vl2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.4
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result_json = response.json()
|
||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
||||
description = result_json["choices"][0]["message"]["content"]
|
||||
return f"[表情包:{description}]"
|
||||
|
||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
||||
|
||||
def get_image_description(self, image_base64: str) -> str:
|
||||
"""调用AI接口获取普通图片描述"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "deepseek-ai/deepseek-vl2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.6
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result_json = response.json()
|
||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
||||
description = result_json["choices"][0]["message"]["content"]
|
||||
return f"[图片:{description}]"
|
||||
|
||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
||||
|
||||
|
||||
def get_image_description_is_setu(self, image_base64: str) -> str:
|
||||
"""调用AI接口获取普通图片描述"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "deepseek-ai/deepseek-vl2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "请回答我这张图片是否涉及涩情、情色、裸露或性暗示,请严格判断,有任何涩情迹象就回答是,请用是或否回答"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 300,
|
||||
"temperature": 0.6
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result_json = response.json()
|
||||
if "choices" in result_json and len(result_json["choices"]) > 0:
|
||||
description = result_json["choices"][0]["message"]["content"]
|
||||
# 如果描述中包含"否",返回否,其他情况返回是
|
||||
return "否" if "否" in description else "是"
|
||||
|
||||
raise ValueError(f"AI接口调用失败: {response.text}")
|
||||
|
||||
def translate_forward(self) -> str:
|
||||
"""处理转发消息"""
|
||||
try:
|
||||
if 'content' not in self.params:
|
||||
return '[转发消息]'
|
||||
|
||||
# 解析content内容(需要先反转义)
|
||||
content = self.unescape(self.params['content'])
|
||||
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
|
||||
# 将字符串形式的列表转换为Python对象
|
||||
import ast
|
||||
try:
|
||||
messages = ast.literal_eval(content)
|
||||
except ValueError as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
|
||||
return '[转发消息]'
|
||||
|
||||
# 处理每条消息
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
sender = msg.get('sender', {})
|
||||
nickname = sender.get('card') or sender.get('nickname', '未知用户')
|
||||
|
||||
# 获取消息内容并使用Message类处理
|
||||
raw_message = msg.get('raw_message', '')
|
||||
message_array = msg.get('message', [])
|
||||
|
||||
if message_array and isinstance(message_array, list):
|
||||
# 检查是否包含嵌套的转发消息
|
||||
for message_part in message_array:
|
||||
if message_part.get('type') == 'forward':
|
||||
content = '[转发消息]'
|
||||
break
|
||||
else:
|
||||
# 处理普通消息
|
||||
if raw_message:
|
||||
from .message import Message
|
||||
message_obj = Message(
|
||||
user_id=msg.get('user_id', 0),
|
||||
message_id=msg.get('message_id', 0),
|
||||
raw_message=raw_message,
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
)
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
else:
|
||||
# 处理普通消息
|
||||
if raw_message:
|
||||
from .message import Message
|
||||
message_obj = Message(
|
||||
user_id=msg.get('user_id', 0),
|
||||
message_id=msg.get('message_id', 0),
|
||||
raw_message=raw_message,
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
)
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
|
||||
formatted_msg = f"{nickname}: {content}"
|
||||
formatted_messages.append(formatted_msg)
|
||||
|
||||
# 合并所有消息
|
||||
combined_messages = '\n'.join(formatted_messages)
|
||||
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
|
||||
return f"[转发消息:\n{combined_messages}]"
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
|
||||
return '[转发消息]'
|
||||
|
||||
def translate_reply(self) -> str:
|
||||
"""处理回复类型的CQ码"""
|
||||
|
||||
# 创建Message对象
|
||||
from .message import Message
|
||||
if self.reply_message == None:
|
||||
return '[回复某人消息]'
|
||||
|
||||
if self.reply_message.sender.user_id:
|
||||
message_obj = Message(
|
||||
user_id=self.reply_message.sender.user_id,
|
||||
message_id=self.reply_message.message_id,
|
||||
raw_message=str(self.reply_message.message),
|
||||
group_id=self.group_id
|
||||
)
|
||||
if message_obj.user_id == global_config.BOT_QQ:
|
||||
return f"[回复 麦麦 的消息: {message_obj.processed_plain_text}]"
|
||||
else:
|
||||
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
|
||||
|
||||
else:
|
||||
return '[回复某人消息]'
|
||||
|
||||
@staticmethod
|
||||
def unescape(text: str) -> str:
|
||||
"""反转义CQ码中的特殊字符"""
|
||||
return text.replace(',', ',') \
|
||||
.replace('[', '[') \
|
||||
.replace(']', ']') \
|
||||
.replace('&', '&')
|
||||
|
||||
@staticmethod
|
||||
def create_emoji_cq(file_path: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
file_path: 本地表情包文件路径
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 确保使用绝对路径
|
||||
abs_path = os.path.abspath(file_path)
|
||||
# 转义特殊字符
|
||||
escaped_path = abs_path.replace('&', '&') \
|
||||
.replace('[', '[') \
|
||||
.replace(']', ']') \
|
||||
.replace(',', ',')
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||
|
||||
@staticmethod
|
||||
def create_reply_cq(message_id: int) -> str:
|
||||
"""
|
||||
创建回复CQ码
|
||||
Args:
|
||||
message_id: 回复的消息ID
|
||||
Returns:
|
||||
回复CQ码字符串
|
||||
"""
|
||||
return f"[CQ:reply,id={message_id}]"
|
||||
414
src/plugins/chat/emoji_manager.py
Normal file
414
src/plugins/chat/emoji_manager.py
Normal file
@@ -0,0 +1,414 @@
|
||||
from typing import List, Dict, Optional
|
||||
import random
|
||||
from ...common.database import Database
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import jieba.analyse as jieba_analyse
|
||||
import aiohttp
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
import base64
|
||||
import shutil
|
||||
from .config import global_config, llm_config
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
_instance = None
|
||||
EMOJI_DIR = "data/emoji" # 表情包存储目录
|
||||
|
||||
EMOTION_KEYWORDS = {
|
||||
'happy': ['开心', '快乐', '高兴', '欢喜', '笑', '喜悦', '兴奋', '愉快', '乐', '好'],
|
||||
'angry': ['生气', '愤怒', '恼火', '不爽', '火大', '怒', '气愤', '恼怒', '发火', '不满'],
|
||||
'sad': ['伤心', '难过', '悲伤', '痛苦', '哭', '忧伤', '悲痛', '哀伤', '委屈', '失落'],
|
||||
'surprised': ['惊讶', '震惊', '吃惊', '意外', '惊', '诧异', '惊奇', '惊喜', '不敢相信', '目瞪口呆'],
|
||||
'disgusted': ['恶心', '讨厌', '厌恶', '反感', '嫌弃', '恶', '嫌恶', '憎恶', '不喜欢', '烦'],
|
||||
'fearful': ['害怕', '恐惧', '惊恐', '担心', '怕', '惊吓', '惊慌', '畏惧', '胆怯', '惧'],
|
||||
'neutral': ['普通', '一般', '还行', '正常', '平静', '平淡', '一般般', '凑合', '还好', '就这样']
|
||||
}
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance.db = None
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
self._scan_task = None
|
||||
|
||||
def _ensure_emoji_dir(self):
|
||||
"""确保表情存储目录存在"""
|
||||
os.makedirs(self.EMOJI_DIR, exist_ok=True)
|
||||
|
||||
def initialize(self):
|
||||
"""初始化数据库连接和表情目录"""
|
||||
if not self._initialized:
|
||||
try:
|
||||
self.db = Database.get_instance()
|
||||
self._ensure_emoji_collection()
|
||||
self._ensure_emoji_dir()
|
||||
self._initialized = True
|
||||
# 启动时执行一次完整性检查
|
||||
self.check_emoji_file_integrity()
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 初始化表情管理器失败: {str(e)}")
|
||||
|
||||
def _ensure_db(self):
|
||||
"""确保数据库已初始化"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
if not self._initialized:
|
||||
raise RuntimeError("EmojiManager not initialized")
|
||||
|
||||
def _ensure_emoji_collection(self):
|
||||
"""确保emoji集合存在并创建索引"""
|
||||
if 'emoji' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('emoji')
|
||||
self.db.db.emoji.create_index([('tags', 1)])
|
||||
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
||||
|
||||
def record_usage(self, emoji_id: str):
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji_id},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 记录表情使用失败: {str(e)}")
|
||||
|
||||
async def _get_emotion_from_text(self, text: str) -> List[str]:
|
||||
"""从文本中识别情感关键词,使用DeepSeek API进行分析
|
||||
Args:
|
||||
text: 输入文本
|
||||
Returns:
|
||||
List[str]: 匹配到的情感标签列表
|
||||
"""
|
||||
try:
|
||||
# 准备请求数据
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f'分析这段文本:"{text}",从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签。只需要返回标签,不要输出其他任何内容。'
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 50,
|
||||
"temperature": 0.3
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"\033[1;31m[错误]\033[0m API请求失败: {await response.text()}")
|
||||
return ['neutral']
|
||||
|
||||
result = json.loads(await response.text())
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
emotion = result["choices"][0]["message"]["content"].strip().lower()
|
||||
# 确保返回的标签是有效的
|
||||
if emotion in self.EMOTION_KEYWORDS:
|
||||
print(f"\033[1;32m[成功]\033[0m 识别到的情感: {emotion}")
|
||||
return [emotion] # 返回单个情感标签的列表
|
||||
|
||||
return ['neutral'] # 如果无法识别情感,返回neutral
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 情感分析失败: {str(e)}")
|
||||
return ['neutral']
|
||||
|
||||
async def get_emoji_for_emotion(self, emotion_tag: str) -> Optional[str]:
|
||||
try:
|
||||
self._ensure_db()
|
||||
|
||||
# 构建查询条件:标签匹配任一情感
|
||||
query = {'tags': {'$in': emotion_tag}}
|
||||
|
||||
# print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}")
|
||||
|
||||
try:
|
||||
# 随机获取一个匹配的表情
|
||||
emoji = self.db.db.emoji.aggregate([
|
||||
{'$match': query},
|
||||
{'$sample': {'size': 1}}
|
||||
]).next()
|
||||
print(f"\033[1;32m[成功]\033[0m 找到匹配的表情")
|
||||
if emoji and 'path' in emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
return emoji['path']
|
||||
except StopIteration:
|
||||
# 如果没有匹配的表情,从所有表情中随机选择一个
|
||||
print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个")
|
||||
try:
|
||||
emoji = self.db.db.emoji.aggregate([
|
||||
{'$sample': {'size': 1}}
|
||||
]).next()
|
||||
if emoji and 'path' in emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
return emoji['path']
|
||||
except StopIteration:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
||||
"""根据文本内容获取相关表情包
|
||||
Args:
|
||||
text: 输入文本
|
||||
Returns:
|
||||
Optional[str]: 表情包文件路径,如果没有找到则返回None
|
||||
"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
# 获取情感标签
|
||||
emotions = await self._get_emotion_from_text(text)
|
||||
print("为 ‘"+ str(text) + "’ 获取到的情感标签为:" + str(emotions))
|
||||
if not emotions:
|
||||
return None
|
||||
|
||||
# 构建查询条件:标签匹配任一情感
|
||||
query = {'tags': {'$in': emotions}}
|
||||
|
||||
print(f"\033[1;34m[调试]\033[0m 表情查询条件: {query}")
|
||||
print(f"\033[1;34m[调试]\033[0m 匹配到的情感: {emotions}")
|
||||
|
||||
try:
|
||||
# 随机获取一个匹配的表情
|
||||
emoji = self.db.db.emoji.aggregate([
|
||||
{'$match': query},
|
||||
{'$sample': {'size': 1}}
|
||||
]).next()
|
||||
print(f"\033[1;32m[成功]\033[0m 找到匹配的表情")
|
||||
if emoji and 'path' in emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
return emoji['path']
|
||||
except StopIteration:
|
||||
# 如果没有匹配的表情,从所有表情中随机选择一个
|
||||
print(f"\033[1;33m[提示]\033[0m 未找到匹配的表情,随机选择一个")
|
||||
try:
|
||||
emoji = self.db.db.emoji.aggregate([
|
||||
{'$sample': {'size': 1}}
|
||||
]).next()
|
||||
if emoji and 'path' in emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
return emoji['path']
|
||||
except StopIteration:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库中没有任何表情")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取表情包失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_emoji_tag(self, image_base64: str) -> str:
|
||||
"""获取表情包的标签"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "deepseek-ai/deepseek-vl2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好'
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"max_tokens": 60,
|
||||
"temperature": 0.3
|
||||
}
|
||||
|
||||
async with session.post(
|
||||
f"{llm_config.SILICONFLOW_BASE_URL}chat/completions",
|
||||
headers=headers,
|
||||
json=payload
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
tag_result = result["choices"][0]["message"]["content"].strip().lower()
|
||||
|
||||
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
|
||||
for tag_match in valid_tags:
|
||||
if tag_match in tag_result or tag_match == tag_result:
|
||||
return tag_match
|
||||
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_match}, 跳过")
|
||||
else:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取标签失败, 状态码: {response.status}")
|
||||
|
||||
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
||||
return "skip" # 默认标签
|
||||
|
||||
async def scan_new_emojis(self):
|
||||
"""扫描新的表情包"""
|
||||
try:
|
||||
emoji_dir = "data/emoji"
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
|
||||
# 获取所有jpg文件
|
||||
files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')]
|
||||
|
||||
for filename in files_to_process:
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||
if existing_emoji:
|
||||
continue
|
||||
|
||||
image_path = os.path.join(emoji_dir, filename)
|
||||
# 读取图片数据
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
# 将图片转换为base64
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
# 获取表情包的情感标签
|
||||
tag = await self._get_emoji_tag(image_base64)
|
||||
if not tag == "skip":
|
||||
# 准备数据库记录
|
||||
emoji_record = {
|
||||
'filename': filename,
|
||||
'path': image_path,
|
||||
'tags': [tag],
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
|
||||
# 保存到数据库
|
||||
self.db.db['emoji'].insert_one(emoji_record)
|
||||
print(f"\033[1;32m[成功]\033[0m 注册新表情包: {filename}")
|
||||
print(f"标签: {tag}")
|
||||
else:
|
||||
print(f"\033[1;33m[警告]\033[0m 跳过表情包: {filename}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 扫描表情包失败: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def _periodic_scan(self, interval_MINS: int = 10):
|
||||
"""定期扫描新表情包"""
|
||||
while True:
|
||||
print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
|
||||
await self.scan_new_emojis()
|
||||
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
|
||||
|
||||
def check_emoji_file_integrity(self):
|
||||
"""检查表情包文件完整性
|
||||
如果文件已被删除,则从数据库中移除对应记录
|
||||
"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
# 获取所有表情包记录
|
||||
all_emojis = list(self.db.db.emoji.find())
|
||||
removed_count = 0
|
||||
total_count = len(all_emojis)
|
||||
|
||||
for emoji in all_emojis:
|
||||
try:
|
||||
if 'path' not in emoji:
|
||||
print(f"\033[1;33m[提示]\033[0m 发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}")
|
||||
self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
removed_count += 1
|
||||
continue
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(emoji['path']):
|
||||
print(f"\033[1;33m[提示]\033[0m 表情包文件已被删除: {emoji['path']}")
|
||||
# 从数据库中删除记录
|
||||
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
if result.deleted_count > 0:
|
||||
print(f"\033[1;32m[成功]\033[0m 成功删除数据库记录: {emoji['_id']}")
|
||||
removed_count += 1
|
||||
else:
|
||||
print(f"\033[1;31m[错误]\033[0m 删除数据库记录失败: {emoji['_id']}")
|
||||
except Exception as item_error:
|
||||
print(f"\033[1;31m[错误]\033[0m 处理表情包记录时出错: {str(item_error)}")
|
||||
continue
|
||||
|
||||
# 验证清理结果
|
||||
remaining_count = self.db.db.emoji.count_documents({})
|
||||
if removed_count > 0:
|
||||
print(f"\033[1;32m[成功]\033[0m 已清理 {removed_count} 个失效的表情包记录")
|
||||
print(f"\033[1;34m[统计]\033[0m 清理前总数: {total_count} | 清理后总数: {remaining_count}")
|
||||
# print(f"\033[1;34m[统计]\033[0m 应删除数量: {removed_count} | 实际删除数量: {total_count - remaining_count}")
|
||||
# 执行数据库压缩
|
||||
try:
|
||||
self.db.db.command({"compact": "emoji"})
|
||||
print(f"\033[1;32m[成功]\033[0m 数据库集合压缩完成")
|
||||
except Exception as compact_error:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库压缩失败: {str(compact_error)}")
|
||||
else:
|
||||
print(f"\033[1;36m[表情包]\033[0m 已检查 {total_count} 个表情包记录")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 检查表情包完整性失败: {str(e)}")
|
||||
import traceback
|
||||
print(f"\033[1;31m[错误追踪]\033[0m\n{traceback.format_exc()}")
|
||||
|
||||
async def start_periodic_check(self, interval_MINS: int = 120):
|
||||
while True:
|
||||
self.check_emoji_file_integrity()
|
||||
await asyncio.sleep(interval_MINS * 60)
|
||||
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
emoji_manager = EmojiManager()
|
||||
544
src/plugins/chat/gpt_response.py
Normal file
544
src/plugins/chat/gpt_response.py
Normal file
@@ -0,0 +1,544 @@
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from openai import OpenAI
|
||||
import asyncio
|
||||
import requests
|
||||
from functools import partial
|
||||
from .message import Message
|
||||
from .config import BotConfig
|
||||
from ...common.database import Database
|
||||
import random
|
||||
import time
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import queue
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from .relationship_manager import relationship_manager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .prompt_builder import prompt_builder
|
||||
from .config import llm_config
|
||||
from .willing_manager import willing_manager
|
||||
from .utils import get_embedding
|
||||
import aiohttp
|
||||
|
||||
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
load_dotenv(os.path.join(root_dir, '.env'))
|
||||
|
||||
# 常见的错别字映射
|
||||
TYPO_DICT = {
|
||||
'的': '地得',
|
||||
'了': '咯啦勒',
|
||||
'吗': '嘛麻',
|
||||
'吧': '八把罢',
|
||||
'是': '事',
|
||||
'在': '再在',
|
||||
'和': '合',
|
||||
'有': '又',
|
||||
'我': '沃窝喔',
|
||||
'你': '泥尼拟',
|
||||
'他': '它她塔祂',
|
||||
'们': '门',
|
||||
'啊': '阿哇',
|
||||
'呢': '呐捏',
|
||||
'都': '豆读毒',
|
||||
'很': '狠',
|
||||
'会': '回汇',
|
||||
'去': '趣取曲',
|
||||
'做': '作坐',
|
||||
'想': '相像',
|
||||
'说': '说税睡',
|
||||
'看': '砍堪刊',
|
||||
'来': '来莱赖',
|
||||
'好': '号毫豪',
|
||||
'给': '给既继',
|
||||
'过': '锅果裹',
|
||||
'能': '嫩',
|
||||
'为': '位未',
|
||||
'什': '甚深伸',
|
||||
'么': '末麽嘛',
|
||||
'话': '话花划',
|
||||
'知': '织直值',
|
||||
'道': '到',
|
||||
'听': '听停挺',
|
||||
'见': '见件建',
|
||||
'觉': '觉脚搅',
|
||||
'得': '得德锝',
|
||||
'着': '着找招',
|
||||
'像': '向象想',
|
||||
'等': '等灯登',
|
||||
'谢': '谢写卸',
|
||||
'对': '对队',
|
||||
'里': '里理鲤',
|
||||
'啦': '啦拉喇',
|
||||
'吃': '吃持迟',
|
||||
'哦': '哦喔噢',
|
||||
'呀': '呀压',
|
||||
'要': '药',
|
||||
'太': '太抬台',
|
||||
'快': '块',
|
||||
'点': '店',
|
||||
'以': '以已',
|
||||
'因': '因应',
|
||||
'啥': '啥沙傻',
|
||||
'行': '行型形',
|
||||
'哈': '哈蛤铪',
|
||||
'嘿': '嘿黑嗨',
|
||||
'嗯': '嗯恩摁',
|
||||
'哎': '哎爱埃',
|
||||
'呜': '呜屋污',
|
||||
'喂': '喂位未',
|
||||
'嘛': '嘛麻马',
|
||||
'嗨': '嗨害亥',
|
||||
'哇': '哇娃蛙',
|
||||
'咦': '咦意易',
|
||||
'嘻': '嘻西希'
|
||||
}
|
||||
|
||||
def random_remove_punctuation(text: str) -> str:
|
||||
"""随机处理标点符号,模拟人类打字习惯"""
|
||||
result = ''
|
||||
text_len = len(text)
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char == '。' and i == text_len - 1: # 结尾的句号
|
||||
if random.random() > 0.4: # 80%概率删除结尾句号
|
||||
continue
|
||||
elif char == ',':
|
||||
rand = random.random()
|
||||
if rand < 0.25: # 5%概率删除逗号
|
||||
continue
|
||||
elif rand < 0.25: # 20%概率把逗号变成空格
|
||||
result += ' '
|
||||
continue
|
||||
result += char
|
||||
return result
|
||||
|
||||
def add_typos(text: str) -> str:
|
||||
"""随机给文本添加错别字"""
|
||||
TYPO_RATE = 0.02 # 控制错别字出现的概率(1%)
|
||||
|
||||
result = ""
|
||||
for char in text:
|
||||
if char in TYPO_DICT and random.random() < TYPO_RATE:
|
||||
# 从可能的错别字中随机选择一个
|
||||
typos = TYPO_DICT[char]
|
||||
result += random.choice(typos)
|
||||
else:
|
||||
result += char
|
||||
return result
|
||||
|
||||
def open_new_console_window(text: str):
|
||||
"""在新的控制台窗口中显示文本"""
|
||||
if sys.platform == 'win32':
|
||||
# 创建一个临时批处理文件
|
||||
temp_bat = "temp_output.bat"
|
||||
with open(temp_bat, "w", encoding="utf-8") as f:
|
||||
f.write(f'@echo off\n')
|
||||
f.write(f'echo {text}\n')
|
||||
f.write('pause\n')
|
||||
|
||||
# 在新窗口中运行批处理文件
|
||||
subprocess.Popen(['start', 'cmd', '/c', temp_bat], shell=True)
|
||||
|
||||
# 等待一会儿再删除批处理文件
|
||||
import threading
|
||||
def delete_bat():
|
||||
import time
|
||||
time.sleep(2)
|
||||
if os.path.exists(temp_bat):
|
||||
os.remove(temp_bat)
|
||||
threading.Thread(target=delete_bat).start()
|
||||
|
||||
class ReasoningWindow:
|
||||
def __init__(self):
|
||||
self.process = None
|
||||
self.message_queue = queue.Queue()
|
||||
self.is_running = False
|
||||
self.content_file = "reasoning_content.txt"
|
||||
|
||||
def start(self):
|
||||
if self.process is None:
|
||||
# 创建用于显示的批处理文件
|
||||
with open("reasoning_window.bat", "w", encoding="utf-8") as f:
|
||||
f.write('@echo off\n')
|
||||
f.write('chcp 65001\n') # 设置UTF-8编码
|
||||
f.write('title Magellan Reasoning Process\n')
|
||||
f.write('echo Waiting for reasoning content...\n')
|
||||
f.write(':loop\n')
|
||||
f.write('if exist "reasoning_update.txt" (\n')
|
||||
f.write(' type "reasoning_update.txt" >> "reasoning_content.txt"\n')
|
||||
f.write(' del "reasoning_update.txt"\n')
|
||||
f.write(' cls\n')
|
||||
f.write(' type "reasoning_content.txt"\n')
|
||||
f.write(')\n')
|
||||
f.write('timeout /t 1 /nobreak >nul\n')
|
||||
f.write('goto loop\n')
|
||||
|
||||
# 清空内容文件
|
||||
with open(self.content_file, "w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
|
||||
# 启动新窗口
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
self.process = subprocess.Popen(['cmd', '/c', 'start', 'reasoning_window.bat'],
|
||||
shell=True,
|
||||
startupinfo=startupinfo)
|
||||
self.is_running = True
|
||||
|
||||
# 启动处理线程
|
||||
threading.Thread(target=self._process_messages, daemon=True).start()
|
||||
|
||||
def _process_messages(self):
|
||||
while self.is_running:
|
||||
try:
|
||||
# 获取新消息
|
||||
text = self.message_queue.get(timeout=1)
|
||||
# 写入更新文件
|
||||
with open("reasoning_update.txt", "w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"处理推理内容时出错: {e}")
|
||||
|
||||
def update_content(self, text: str):
|
||||
if self.is_running:
|
||||
self.message_queue.put(text)
|
||||
|
||||
def stop(self):
|
||||
self.is_running = False
|
||||
if self.process:
|
||||
self.process.terminate()
|
||||
self.process = None
|
||||
# 清理文件
|
||||
for file in ["reasoning_window.bat", "reasoning_content.txt", "reasoning_update.txt"]:
|
||||
if os.path.exists(file):
|
||||
os.remove(file)
|
||||
|
||||
# 创建全局单例
|
||||
reasoning_window = ReasoningWindow()
|
||||
|
||||
class GPTResponseGenerator:
|
||||
def __init__(self, config: BotConfig):
|
||||
self.config = config
|
||||
self.client = OpenAI(
|
||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
||||
)
|
||||
|
||||
self.db = Database.get_instance()
|
||||
reasoning_window.start()
|
||||
# 当前使用的模型类型
|
||||
self.current_model_type = 'r1' # 默认使用 R1
|
||||
|
||||
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 使用随机数选择模型
|
||||
rand = random.random()
|
||||
if rand < 0.15: # 40%概率使用 R1
|
||||
self.current_model_type = "r1"
|
||||
elif rand < 0.8: # 30%概率使用 V3
|
||||
self.current_model_type = "v3"
|
||||
else: # 30%概率使用 R1-Distill
|
||||
self.current_model_type = "r1_distill"
|
||||
|
||||
print(f"+++++++++++++++++麦麦{self.current_model_type}思考中+++++++++++++++++")
|
||||
if self.current_model_type == 'r1':
|
||||
model_response = await self._generate_r1_response(message)
|
||||
elif self.current_model_type == 'v3':
|
||||
model_response = await self._generate_v3_response(message)
|
||||
else:
|
||||
model_response = await self._generate_r1_distill_response(message)
|
||||
|
||||
# 打印情感标签
|
||||
print(f'麦麦的回复是:{model_response}')
|
||||
model_response , emotion = await self._process_response(model_response)
|
||||
|
||||
if model_response:
|
||||
print(f"为 '{model_response}' 获取到的情感标签为:{emotion}")
|
||||
|
||||
return model_response,emotion
|
||||
|
||||
async def _generate_r1_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]:
|
||||
"""使用 DeepSeek-R1 模型生成回复"""
|
||||
# 获取群聊上下文
|
||||
group_chat = await self._get_group_chat_context(message)
|
||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if relationship_manager.get_relationship(message.user_id):
|
||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
|
||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||
else:
|
||||
relationship_value = 0.0
|
||||
|
||||
# 构建 prompt
|
||||
prompt = prompt_builder._build_prompt(
|
||||
message_txt=message.processed_plain_text,
|
||||
sender_name=sender_name,
|
||||
relationship_value=relationship_value,
|
||||
group_id=message.group_id
|
||||
)
|
||||
|
||||
def create_completion():
|
||||
return self.client.chat.completions.create(
|
||||
model="Pro/deepseek-ai/DeepSeek-R1",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=False,
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, create_completion)
|
||||
if response.choices[0].message.content:
|
||||
print(response.choices[0].message.content)
|
||||
print(response.choices[0].message.reasoning_content)
|
||||
# 处理 R1 特有的返回格式
|
||||
content = response.choices[0].message.content
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
else:
|
||||
return None
|
||||
# 更新推理窗口
|
||||
self._update_reasoning_window(message, prompt, reasoning_content, content, sender_name)
|
||||
|
||||
return content
|
||||
|
||||
async def _generate_v3_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]:
|
||||
"""使用 DeepSeek-V3 模型生成回复"""
|
||||
# 获取群聊上下文
|
||||
group_chat = await self._get_group_chat_context(message)
|
||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||
|
||||
if relationship_manager.get_relationship(message.user_id):
|
||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
|
||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||
else:
|
||||
relationship_value = 0.0
|
||||
|
||||
prompt = prompt_builder._build_prompt(message.processed_plain_text, sender_name,relationship_value,group_id=message.group_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
create_completion = partial(
|
||||
self.client.chat.completions.create,
|
||||
model="Pro/deepseek-ai/DeepSeek-V3",
|
||||
messages=messages,
|
||||
stream=False,
|
||||
max_tokens=1024,
|
||||
temperature=0.8
|
||||
)
|
||||
response = await loop.run_in_executor(None, create_completion)
|
||||
|
||||
if response.choices[0].message.content:
|
||||
content = response.choices[0].message.content
|
||||
# V3 模型没有 reasoning_content
|
||||
self._update_reasoning_window(message, prompt, "V3模型无推理过程", content, sender_name)
|
||||
return content
|
||||
else:
|
||||
print(f"[ERROR] V3 回复发送生成失败: {response}")
|
||||
|
||||
return None, [] # 返回元组
|
||||
|
||||
async def _generate_r1_distill_response(self, message: Message) -> Optional[Tuple[Union[str, List[str]], List[str]]]:
|
||||
"""使用 DeepSeek-R1-Distill-Qwen-32B 模型生成回复"""
|
||||
# 获取群聊上下文
|
||||
group_chat = await self._get_group_chat_context(message)
|
||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if relationship_manager.get_relationship(message.user_id):
|
||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value
|
||||
print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||
else:
|
||||
relationship_value = 0.0
|
||||
|
||||
# 构建 prompt
|
||||
prompt = prompt_builder._build_prompt(
|
||||
message_txt=message.processed_plain_text,
|
||||
sender_name=sender_name,
|
||||
relationship_value=relationship_value,
|
||||
group_id=message.group_id
|
||||
)
|
||||
|
||||
def create_completion():
|
||||
return self.client.chat.completions.create(
|
||||
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=False,
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, create_completion)
|
||||
if response.choices[0].message.content:
|
||||
print(response.choices[0].message.content)
|
||||
print(response.choices[0].message.reasoning_content)
|
||||
# 处理 R1 特有的返回格式
|
||||
content = response.choices[0].message.content
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
else:
|
||||
return None
|
||||
# 更新推理窗口
|
||||
self._update_reasoning_window(message, prompt, reasoning_content, content, sender_name)
|
||||
|
||||
return content
|
||||
|
||||
async def _get_group_chat_context(self, message: Message) -> str:
|
||||
"""获取群聊上下文"""
|
||||
recent_messages = self.db.db.messages.find(
|
||||
{"group_id": message.group_id}
|
||||
).sort("time", -1).limit(15)
|
||||
|
||||
messages_list = list(recent_messages)[::-1]
|
||||
group_chat = ""
|
||||
|
||||
for msg_dict in messages_list:
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(msg_dict['time']))
|
||||
display_name = msg_dict.get('user_nickname', f"用户{msg_dict['user_id']}")
|
||||
content = msg_dict.get('processed_plain_text', msg_dict['plain_text'])
|
||||
|
||||
group_chat += f"[{time_str}] {display_name}: {content}\n"
|
||||
|
||||
return group_chat
|
||||
|
||||
def _update_reasoning_window(self, message, prompt, reasoning_content, content, sender_name):
|
||||
"""更新推理窗口内容"""
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 获取当前使用的模型名称
|
||||
model_name = {
|
||||
'r1': 'DeepSeek-R1',
|
||||
'v3': 'DeepSeek-V3',
|
||||
'r1_distill': 'DeepSeek-R1-Distill-Qwen-32B'
|
||||
}.get(self.current_model_type, '未知模型')
|
||||
|
||||
display_text = (
|
||||
f"Time: {current_time}\n"
|
||||
f"Group: {message.group_name}\n"
|
||||
f"User: {sender_name}\n"
|
||||
f"Model: {model_name}\n"
|
||||
f"\033[1;32mMessage:\033[0m {message.processed_plain_text}\n\n"
|
||||
f"\033[1;32mPrompt:\033[0m \n{prompt}\n"
|
||||
f"\n-------------------------------------------------------"
|
||||
f"\n\033[1;32mReasoning Process:\033[0m\n{reasoning_content}\n"
|
||||
f"\n\033[1;32mResponse Content:\033[0m\n{content}\n"
|
||||
f"\n{'='*50}\n"
|
||||
)
|
||||
reasoning_window.update_content(display_text)
|
||||
|
||||
async def _get_emotion_tags(self, content: str) -> List[str]:
|
||||
"""提取情感标签"""
|
||||
try:
|
||||
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
|
||||
只输出标签就好,不要输出其他内容:
|
||||
内容:{content}
|
||||
输出:
|
||||
'''
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
create_completion = partial(
|
||||
self.client.chat.completions.create,
|
||||
model="Pro/deepseek-ai/DeepSeek-V3",
|
||||
messages=messages,
|
||||
stream=False,
|
||||
max_tokens=30,
|
||||
temperature=0.6
|
||||
)
|
||||
response = await loop.run_in_executor(None, create_completion)
|
||||
|
||||
if response.choices[0].message.content:
|
||||
# 确保返回的是列表格式
|
||||
emotion_tag = response.choices[0].message.content.strip()
|
||||
return [emotion_tag] # 将单个标签包装成列表返回
|
||||
|
||||
return ["neutral"] # 如果无法获取情感标签,返回默认值
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取情感标签时出错: {e}")
|
||||
return ["neutral"] # 发生错误时返回默认值
|
||||
|
||||
async def _process_response(self, content: str) -> Tuple[Union[str, List[str]], List[str]]:
|
||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||
if not content:
|
||||
return None, []
|
||||
|
||||
emotion_tags = await self._get_emotion_tags(content)
|
||||
|
||||
# 添加错别字和处理标点符号
|
||||
if random.random() < 0.9: # 90%概率进行处理
|
||||
processed_response = random_remove_punctuation(add_typos(content))
|
||||
else:
|
||||
processed_response = content
|
||||
# 处理长消息
|
||||
if len(processed_response) > 5:
|
||||
sentences = self._split_into_sentences(processed_response)
|
||||
print(f"分割后的句子: {sentences}")
|
||||
messages = []
|
||||
current_message = ""
|
||||
|
||||
for sentence in sentences:
|
||||
if len(current_message) + len(sentence) <= 5:
|
||||
current_message += ' '
|
||||
current_message += sentence
|
||||
else:
|
||||
if current_message:
|
||||
messages.append(current_message.strip())
|
||||
current_message = sentence
|
||||
|
||||
if current_message:
|
||||
messages.append(current_message.strip())
|
||||
|
||||
# 翻转消息顺序
|
||||
# messages.reverse()
|
||||
|
||||
return messages, emotion_tags
|
||||
|
||||
return processed_response, emotion_tags
|
||||
|
||||
def _split_into_sentences(self, text: str) -> List[str]:
|
||||
"""将文本分割成句子,但保持书名号中的内容完整"""
|
||||
delimiters = ['。', '!', ',', ',', '?', '…', '!', '?', '\n'] # 添加换行符作为分隔符
|
||||
remove_chars = [',', ','] # 只移除这两种逗号
|
||||
sentences = []
|
||||
current_sentence = ""
|
||||
in_book_title = False # 标记是否在书名号内
|
||||
|
||||
for char in text:
|
||||
current_sentence += char
|
||||
|
||||
# 检查书名号
|
||||
if char == '《':
|
||||
in_book_title = True
|
||||
elif char == '》':
|
||||
in_book_title = False
|
||||
|
||||
# 只有不在书名号内且是分隔符时才分割
|
||||
if char in delimiters and not in_book_title:
|
||||
if current_sentence.strip(): # 确保不是空字符串
|
||||
# 只移除逗号
|
||||
clean_sentence = current_sentence
|
||||
if clean_sentence[-1] in remove_chars:
|
||||
clean_sentence = clean_sentence[:-1]
|
||||
if clean_sentence.strip():
|
||||
sentences.append(clean_sentence.strip())
|
||||
current_sentence = ""
|
||||
|
||||
# 处理最后一个句子
|
||||
if current_sentence.strip():
|
||||
# 如果最后一个字符是逗号,移除它
|
||||
if current_sentence[-1] in remove_chars:
|
||||
current_sentence = current_sentence[:-1]
|
||||
sentences.append(current_sentence.strip())
|
||||
|
||||
# 过滤掉空字符串
|
||||
sentences = [s for s in sentences if s.strip()]
|
||||
|
||||
return sentences
|
||||
|
||||
|
||||
# llm_response = GPTResponseGenerator(config=BotConfig())
|
||||
107
src/plugins/chat/group_info_manager.py
Normal file
107
src/plugins/chat/group_info_manager.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import Dict, Optional
|
||||
from ...common.database import Database
|
||||
import time
|
||||
|
||||
class GroupInfoManager:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
# 确保必要的集合存在
|
||||
self._ensure_collections()
|
||||
|
||||
def _ensure_collections(self):
|
||||
"""确保数据库中有必要的集合"""
|
||||
collections = self.db.db.list_collection_names()
|
||||
if 'group_info' not in collections:
|
||||
self.db.db.create_collection('group_info')
|
||||
if 'user_info' not in collections:
|
||||
self.db.db.create_collection('user_info')
|
||||
|
||||
async def update_group_info(self, group_id: int, group_name: str, group_notice: str = "",
|
||||
member_count: int = 0, admins: list = None):
|
||||
"""更新群组信息"""
|
||||
try:
|
||||
group_data = {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"group_notice": group_notice,
|
||||
"member_count": member_count,
|
||||
"admins": admins or [],
|
||||
"last_updated": time.time()
|
||||
}
|
||||
|
||||
# 使用 upsert 来更新或插入数据
|
||||
self.db.db.group_info.update_one(
|
||||
{"group_id": group_id},
|
||||
{"$set": group_data},
|
||||
upsert=True
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 更新群信息失败: {str(e)}")
|
||||
|
||||
async def update_user_info(self, user_id: int, nickname: str, group_id: int = None,
|
||||
group_card: str = None, age: int = None, gender: str = None,
|
||||
location: str = None):
|
||||
"""更新用户信息"""
|
||||
try:
|
||||
# 基础用户数据
|
||||
user_data = {
|
||||
"user_id": user_id,
|
||||
"nickname": nickname,
|
||||
"last_updated": time.time()
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if age is not None:
|
||||
user_data["age"] = age
|
||||
if gender is not None:
|
||||
user_data["gender"] = gender
|
||||
if location is not None:
|
||||
user_data["location"] = location
|
||||
|
||||
# 如果提供了群相关信息,更新用户在该群的信息
|
||||
if group_id is not None:
|
||||
group_info_key = f"group_info.{group_id}"
|
||||
group_data = {
|
||||
group_info_key: {
|
||||
"group_card": group_card,
|
||||
"last_active": time.time()
|
||||
}
|
||||
}
|
||||
user_data.update(group_data)
|
||||
|
||||
# 使用 upsert 来更新或插入数据
|
||||
result = self.db.db.user_info.update_one(
|
||||
{"user_id": user_id},
|
||||
{
|
||||
"$set": user_data,
|
||||
"$addToSet": {"groups": group_id} if group_id else {}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
# print(f"\033[1;32m[用户信息]\033[0m 更新用户 {nickname}({user_id}) 的信息 {'成功' if result.modified_count > 0 or result.upserted_id else '未变化'}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 更新用户信息失败: {str(e)}")
|
||||
print(f"用户ID: {user_id}, 昵称: {nickname}, 群ID: {group_id}, 群名片: {group_card}")
|
||||
|
||||
async def get_group_info(self, group_id: int) -> Optional[Dict]:
|
||||
"""获取群组信息"""
|
||||
try:
|
||||
return self.db.db.group_info.find_one({"group_id": group_id})
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取群信息失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_user_info(self, user_id: int, group_id: int = None) -> Optional[Dict]:
|
||||
"""获取用户信息"""
|
||||
try:
|
||||
user_info = self.db.db.user_info.find_one({"user_id": user_id})
|
||||
if user_info and group_id:
|
||||
# 添加该用户在特定群的信息
|
||||
group_info_key = f"group_info.{group_id}"
|
||||
user_info["current_group_info"] = user_info.get(group_info_key, {})
|
||||
return user_info
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取用户信息失败: {str(e)}")
|
||||
return None
|
||||
162
src/plugins/chat/image_utils.py
Normal file
162
src/plugins/chat/image_utils.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import io
|
||||
from PIL import Image
|
||||
import hashlib
|
||||
import time
|
||||
import os
|
||||
from ...common.database import Database
|
||||
from .config import BotConfig
|
||||
import zlib # 用于 CRC32
|
||||
import base64
|
||||
|
||||
bot_config = BotConfig.load_config()
|
||||
|
||||
def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
"""
|
||||
压缩图片到指定大小(单位:KB)并在数据库中记录图片信息
|
||||
Args:
|
||||
image_data: 图片字节数据
|
||||
group_id: 群组ID
|
||||
user_id: 用户ID
|
||||
max_size: 最大文件大小(KB)
|
||||
"""
|
||||
try:
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
# 确保图片目录存在
|
||||
images_dir = "data/images"
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
|
||||
# 连接数据库
|
||||
db = Database(
|
||||
host=bot_config.MONGODB_HOST,
|
||||
port=bot_config.MONGODB_PORT,
|
||||
db_name=bot_config.DATABASE_NAME
|
||||
)
|
||||
|
||||
# 检查是否已存在相同哈希值的图片
|
||||
collection = db.db['images']
|
||||
existing_image = collection.find_one({'hash': hash_value})
|
||||
|
||||
if existing_image:
|
||||
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
|
||||
return image_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 如果是动图,直接返回原图
|
||||
if getattr(img, 'is_animated', False):
|
||||
return image_data
|
||||
|
||||
# 计算当前大小(KB)
|
||||
current_size = len(image_data) / 1024
|
||||
|
||||
# 如果已经小于目标大小,直接使用原图
|
||||
if current_size <= max_size:
|
||||
compressed_data = image_data
|
||||
else:
|
||||
# 压缩逻辑
|
||||
# 先缩放到50%
|
||||
new_width = int(img.width * 0.5)
|
||||
new_height = int(img.height * 0.5)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 如果缩放后的最大边长仍然大于400,继续缩放
|
||||
max_dimension = 400
|
||||
max_current = max(new_width, new_height)
|
||||
if max_current > max_dimension:
|
||||
ratio = max_dimension / max_current
|
||||
new_width = int(new_width * ratio)
|
||||
new_height = int(new_height * ratio)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为RGB模式(去除透明通道)
|
||||
if img.mode in ('RGBA', 'P'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 使用固定质量参数压缩
|
||||
output = io.BytesIO()
|
||||
img.save(output, format='JPEG', quality=85, optimize=True)
|
||||
compressed_data = output.getvalue()
|
||||
|
||||
# 生成文件名(使用时间戳和哈希值确保唯一性)
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{hash_value}.jpg"
|
||||
image_path = os.path.join(images_dir, filename)
|
||||
|
||||
# 保存文件
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(compressed_data)
|
||||
|
||||
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
|
||||
|
||||
try:
|
||||
# 准备数据库记录
|
||||
image_record = {
|
||||
'filename': filename,
|
||||
'path': image_path,
|
||||
'size': len(compressed_data) / 1024,
|
||||
'timestamp': timestamp,
|
||||
'width': img.width,
|
||||
'height': img.height,
|
||||
'description': '',
|
||||
'tags': [],
|
||||
'type': 'image',
|
||||
'hash': hash_value
|
||||
}
|
||||
|
||||
# 保存记录
|
||||
collection.insert_one(image_record)
|
||||
print(f"\033[1;32m[成功]\033[0m 保存图片记录到数据库")
|
||||
|
||||
except Exception as db_error:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
|
||||
|
||||
return compressed_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return image_data
|
||||
|
||||
def storage_emoji(image_data: bytes) -> bytes:
|
||||
"""
|
||||
存储表情包到本地文件夹
|
||||
Args:
|
||||
image_data: 图片字节数据
|
||||
group_id: 群组ID(仅用于日志)
|
||||
user_id: 用户ID(仅用于日志)
|
||||
Returns:
|
||||
bytes: 原始图片数据
|
||||
"""
|
||||
try:
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
# 确保表情包目录存在
|
||||
emoji_dir = "data/emoji"
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
|
||||
# 检查是否已存在相同哈希值的文件
|
||||
for filename in os.listdir(emoji_dir):
|
||||
if hash_value in filename:
|
||||
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
|
||||
return image_data
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{hash_value}.jpg"
|
||||
emoji_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
# 直接保存原始文件
|
||||
with open(emoji_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
|
||||
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
|
||||
return image_data
|
||||
76
src/plugins/chat/info_gui.py
Normal file
76
src/plugins/chat/info_gui.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import gradio as gr
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List
|
||||
from .message import Message
|
||||
|
||||
class MessageWindow:
|
||||
def __init__(self):
|
||||
self.interface = None
|
||||
self._running = False
|
||||
self.messages_history = []
|
||||
|
||||
def _create_window(self):
|
||||
"""创建Gradio界面"""
|
||||
with gr.Blocks(title="实时消息监控") as self.interface:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
self.message_box = gr.Dataframe(
|
||||
headers=["时间", "群号", "发送者", "消息内容"],
|
||||
datatype=["str", "str", "str", "str"],
|
||||
row_count=20,
|
||||
col_count=(4, "fixed"),
|
||||
interactive=False,
|
||||
wrap=True
|
||||
)
|
||||
|
||||
# 每1秒自动刷新
|
||||
self.interface.load(self._update_display, None, [self.message_box], every=1)
|
||||
|
||||
# 启动界面
|
||||
self.interface.queue()
|
||||
self._running = True
|
||||
self.interface.launch(share=False, server_port=7860)
|
||||
|
||||
def _update_display(self):
|
||||
"""更新消息显示"""
|
||||
display_data = []
|
||||
for msg in self.messages_history[-1000:]: # 只显示最近1000条消息
|
||||
time_str = time.strftime("%H:%M:%S", time.localtime(msg["time"]))
|
||||
display_data.append([
|
||||
time_str,
|
||||
str(msg["group_id"]),
|
||||
f"{msg['user_nickname']}({msg['user_id']})",
|
||||
msg["plain_text"]
|
||||
])
|
||||
return display_data
|
||||
|
||||
def update_messages(self, group_id: int, messages: List[Message]):
|
||||
"""接收新消息更新"""
|
||||
for msg in messages:
|
||||
self.messages_history.append({
|
||||
"time": msg.time,
|
||||
"group_id": group_id,
|
||||
"user_id": msg.user_id,
|
||||
"user_nickname": msg.user_nickname,
|
||||
"plain_text": msg.plain_text
|
||||
})
|
||||
|
||||
# 保持最多存储1000条消息
|
||||
if len(self.messages_history) > 1000:
|
||||
self.messages_history = self.messages_history[-1000:]
|
||||
|
||||
def start(self):
|
||||
"""启动窗口"""
|
||||
# 在新线程中启动窗口
|
||||
threading.Thread(target=self._create_window, daemon=True).start()
|
||||
|
||||
def stop(self):
|
||||
"""停止窗口"""
|
||||
self._running = False
|
||||
if self.interface:
|
||||
self.interface.close()
|
||||
|
||||
# 创建全局实例
|
||||
message_window = MessageWindow()
|
||||
|
||||
108
src/plugins/chat/llm_generator.py
Normal file
108
src/plugins/chat/llm_generator.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from typing import Dict, Any, List, Optional, Union, Tuple
|
||||
from openai import OpenAI
|
||||
from functools import partial
|
||||
from .config import BotConfig
|
||||
from ...common.database import Database
|
||||
import random
|
||||
import os
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from .relationship_manager import relationship_manager
|
||||
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
load_dotenv(os.path.join(root_dir, '.env'))
|
||||
|
||||
|
||||
|
||||
class LLMResponseGenerator:
|
||||
def __init__(self, config: BotConfig):
|
||||
self.config = config
|
||||
self.API_KEY = os.getenv('SILICONFLOW_KEY')
|
||||
self.BASE_URL =os.getenv('SILICONFLOW_BASE_URL')
|
||||
self.client = OpenAI(
|
||||
api_key=self.API_KEY,
|
||||
base_url=self.BASE_URL
|
||||
)
|
||||
|
||||
self.db = Database.get_instance()
|
||||
# 当前使用的模型类型
|
||||
self.current_model_type = 'r1' # 默认使用 R1
|
||||
|
||||
async def generate_response(self, text: str) -> Optional[str]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
if random.random() < self.config.MODEL_R1_PROBABILITY:
|
||||
self.current_model_type = "r1"
|
||||
else:
|
||||
self.current_model_type = "v3"
|
||||
|
||||
print(f"+++++++++++++++++麦麦{self.current_model_type}思考中+++++++++++++++++")
|
||||
if self.current_model_type == 'r1':
|
||||
model_response = await self._generate_v3_response(text)
|
||||
else:
|
||||
model_response = await self._generate_v3_response(text)
|
||||
# 打印情感标签
|
||||
print(f'麦麦的回复------------------------------是:{model_response}')
|
||||
|
||||
return model_response
|
||||
|
||||
async def _generate_r1_response(self, text: str) -> Optional[str]:
|
||||
"""使用 DeepSeek-R1 模型生成回复"""
|
||||
messages = [{"role": "user", "content": text}]
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.API_KEY}"
|
||||
}
|
||||
payload = {
|
||||
"model": "Pro/deepseek-ai/DeepSeek-R1",
|
||||
"messages": messages,
|
||||
"stream": False,
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.8
|
||||
}
|
||||
async with session.post(f"{self.BASE_URL}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload) as response:
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
print(f"Content: {content}")
|
||||
print(f"Reasoning: {reasoning_content}")
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
async def _generate_v3_response(self, text: str) -> Optional[str]:
|
||||
"""使用 DeepSeek-V3 模型生成回复"""
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.API_KEY}"
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "Pro/deepseek-ai/DeepSeek-V3",
|
||||
"messages": [{"role": "user", "content": text}],
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.8
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.BASE_URL}/chat/completions",
|
||||
headers=headers,
|
||||
json=payload) as response:
|
||||
result = await response.json()
|
||||
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
return content
|
||||
else:
|
||||
print(f"[ERROR] V3 回复发送生成失败: {result}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
llm_response = LLMResponseGenerator(config=BotConfig())
|
||||
318
src/plugins/chat/message.py
Normal file
318
src/plugins/chat/message.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict, Tuple, ForwardRef
|
||||
import time
|
||||
import jieba.analyse as jieba_analyse
|
||||
import os
|
||||
from datetime import datetime
|
||||
from ...common.database import Database
|
||||
from PIL import Image
|
||||
from .config import BotConfig, global_config
|
||||
import urllib3
|
||||
from .cq_code import CQCode
|
||||
|
||||
Message = ForwardRef('Message') # 添加这行
|
||||
|
||||
# 加载配置
|
||||
bot_config = BotConfig.load_config()
|
||||
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
#这个类是消息数据类,用于存储和管理消息数据。
|
||||
#它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
||||
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""消息数据类"""
|
||||
group_id: int = None
|
||||
user_id: int = None
|
||||
user_nickname: str = None # 用户昵称
|
||||
group_name: str = None # 群名称
|
||||
|
||||
message_id: int = None
|
||||
raw_message: str = None
|
||||
plain_text: str = None
|
||||
|
||||
message_based_id: int = None
|
||||
reply_message: Dict = None # 存储回复消息
|
||||
|
||||
message_segments: List[Dict] = None # 存储解析后的消息片段
|
||||
processed_plain_text: str = None # 用于存储处理后的plain_text
|
||||
|
||||
time: float = None
|
||||
|
||||
is_emoji: bool = False # 是否是表情包
|
||||
|
||||
|
||||
|
||||
reply_benefits: float = 0.0
|
||||
|
||||
type: str = 'received' # 消息类型,可以是received或者send
|
||||
|
||||
|
||||
|
||||
"""消息数据类:思考消息"""
|
||||
|
||||
# 思考状态相关属性
|
||||
is_thinking: bool = False
|
||||
thinking_text: str = "正在思考..."
|
||||
thingking_start_time: float = None
|
||||
thinking_time: float = 0
|
||||
|
||||
received_message = ''
|
||||
thinking_response = ''
|
||||
|
||||
def __post_init__(self):
|
||||
if self.time is None:
|
||||
self.time = int(time.time())
|
||||
|
||||
if not self.user_nickname:
|
||||
self.user_nickname = self.get_user_nickname(self.user_id)
|
||||
|
||||
if not self.group_name:
|
||||
self.group_name = self.get_groupname(self.group_id)
|
||||
|
||||
if not self.processed_plain_text:
|
||||
# 解析消息片段
|
||||
if self.raw_message:
|
||||
# print(f"\033[1;34m[调试信息]\033[0m 原始消息: {self.raw_message}")
|
||||
self.message_segments = self.parse_message_segments(str(self.raw_message))
|
||||
self.processed_plain_text = ' '.join(
|
||||
seg['translated_text']
|
||||
for seg in self.message_segments
|
||||
)
|
||||
|
||||
# print(f"\033[1;34m[调试]\033[0m pppttt消息: {self.processed_plain_text}")
|
||||
def get_user_nickname(self, user_id: int) -> str:
|
||||
"""
|
||||
根据user_id获取用户昵称
|
||||
如果数据库中找不到,则返回默认昵称
|
||||
"""
|
||||
if not user_id:
|
||||
return "未知用户"
|
||||
|
||||
user_id = int(user_id)
|
||||
if user_id == int(global_config.BOT_QQ):
|
||||
return "麦麦"
|
||||
|
||||
# 使用数据库单例
|
||||
db = Database.get_instance()
|
||||
# 查找用户,打印查询条件和结果
|
||||
query = {'user_id': user_id}
|
||||
user = db.db.user_info.find_one(query)
|
||||
if user:
|
||||
return user.get('nickname') or f"用户{user_id}"
|
||||
else:
|
||||
return f"用户{user_id}"
|
||||
|
||||
def get_groupname(self, group_id: int) -> str:
|
||||
if not group_id:
|
||||
return "未知群"
|
||||
group_id = int(group_id)
|
||||
# 使用数据库单例
|
||||
db = Database.get_instance()
|
||||
# 查找用户,打印查询条件和结果
|
||||
query = {'group_id': group_id}
|
||||
group = db.db.group_info.find_one(query)
|
||||
if group:
|
||||
return group.get('group_name')
|
||||
else:
|
||||
return f"群{group_id}"
|
||||
|
||||
def parse_message_segments(self, message: str) -> List[Dict]:
|
||||
"""
|
||||
将消息解析为片段列表,包括纯文本和CQ码
|
||||
返回的列表中每个元素都是字典,包含:
|
||||
- type: 'text' 或 CQ码类型
|
||||
- data: 对于text类型是文本内容,对于CQ码是参数字典
|
||||
- translated_text: 经过处理(如AI翻译)后的文本
|
||||
"""
|
||||
segments = []
|
||||
start = 0
|
||||
|
||||
while True:
|
||||
# 查找下一个CQ码的开始位置
|
||||
cq_start = message.find('[CQ:', start)
|
||||
if cq_start == -1:
|
||||
# 如果没有找到更多CQ码,添加剩余文本
|
||||
if start < len(message):
|
||||
text = message[start:].strip()
|
||||
if text: # 只添加非空文本
|
||||
segments.append({
|
||||
'type': 'text',
|
||||
'data': {'text': text},
|
||||
'translated_text': text
|
||||
})
|
||||
break
|
||||
|
||||
# 添加CQ码前的文本
|
||||
if cq_start > start:
|
||||
text = message[start:cq_start].strip()
|
||||
if text: # 只添加非空文本
|
||||
segments.append({
|
||||
'type': 'text',
|
||||
'data': {'text': text},
|
||||
'translated_text': text
|
||||
})
|
||||
|
||||
# 查找CQ码的结束位置
|
||||
cq_end = message.find(']', cq_start)
|
||||
if cq_end == -1:
|
||||
# CQ码未闭合,作为普通文本处理
|
||||
text = message[cq_start:].strip()
|
||||
if text:
|
||||
segments.append({
|
||||
'type': 'text',
|
||||
'data': {'text': text},
|
||||
'translated_text': text
|
||||
})
|
||||
break
|
||||
|
||||
# 提取完整的CQ码并创建CQCode对象
|
||||
cq_code = message[cq_start:cq_end + 1]
|
||||
try:
|
||||
cq_obj = CQCode.from_cq_code(cq_code,reply = self.reply_message)
|
||||
# 设置必要的属性
|
||||
segments.append({
|
||||
'type': cq_obj.type,
|
||||
'data': cq_obj.params,
|
||||
'translated_text': cq_obj.translated_plain_text
|
||||
})
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"\033[1;31m[错误]\033[0m 处理CQ码失败: {str(e)}")
|
||||
print(f"CQ码内容: {cq_code}")
|
||||
print(f"当前消息属性:")
|
||||
print(f"- group_id: {self.group_id}")
|
||||
print(f"- user_id: {self.user_id}")
|
||||
print(f"- user_nickname: {self.user_nickname}")
|
||||
print(f"- group_name: {self.group_name}")
|
||||
print("详细错误信息:")
|
||||
print(traceback.format_exc())
|
||||
# 处理失败时,将CQ码作为普通文本处理
|
||||
segments.append({
|
||||
'type': 'text',
|
||||
'data': {'text': cq_code},
|
||||
'translated_text': cq_code
|
||||
})
|
||||
|
||||
start = cq_end + 1
|
||||
|
||||
# 检查是否只包含一个表情包CQ码
|
||||
if len(segments) == 1 and segments[0]['type'] == 'image':
|
||||
# 检查图片的 subtype 是否为 0(表情包)
|
||||
if segments[0]['data'].get('subtype') == '0':
|
||||
self.is_emoji = True
|
||||
|
||||
return segments
|
||||
|
||||
class Message_Thinking:
|
||||
"""消息思考类"""
|
||||
def __init__(self, message: Message,message_id: str):
|
||||
# 复制原始消息的基本属性
|
||||
self.group_id = message.group_id
|
||||
self.user_id = message.user_id
|
||||
self.user_nickname = message.user_nickname
|
||||
self.group_name = message.group_name
|
||||
|
||||
self.message_id = message_id
|
||||
|
||||
# 思考状态相关属性
|
||||
self.thinking_text = "正在思考..."
|
||||
self.time = int(time.time())
|
||||
|
||||
def update_to_message(self, done_message: Message) -> Message:
|
||||
"""更新为完整消息"""
|
||||
|
||||
return done_message
|
||||
|
||||
@property
|
||||
def processed_plain_text(self) -> str:
|
||||
"""获取处理后的文本"""
|
||||
return self.thinking_text
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"[思考中] 群:{self.group_id} 用户:{self.user_nickname} 时间:{self.time} 消息ID:{self.message_id}"
|
||||
|
||||
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个相关的消息"""
|
||||
def __init__(self, group_id: int, user_id: int, message_id: str):
|
||||
self.group_id = group_id
|
||||
self.user_id = user_id
|
||||
self.message_id = message_id
|
||||
self.messages: List[Message] = []
|
||||
self.time = round(time.time(), 2)
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到集合"""
|
||||
self.messages.append(message)
|
||||
# 按时间排序
|
||||
self.messages.sort(key=lambda x: x.time)
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[Message]:
|
||||
"""通过索引获取消息"""
|
||||
if 0 <= index < len(self.messages):
|
||||
return self.messages[index]
|
||||
return None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[Message]:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
# 使用二分查找找到最接近的消息
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].time < target_time:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
def get_latest_message(self) -> Optional[Message]:
|
||||
"""获取最新的消息"""
|
||||
return self.messages[-1] if self.messages else None
|
||||
|
||||
def get_earliest_message(self) -> Optional[Message]:
|
||||
"""获取最早的消息"""
|
||||
return self.messages[0] if self.messages else None
|
||||
|
||||
def get_all_messages(self) -> List[Message]:
|
||||
"""获取所有消息"""
|
||||
return self.messages.copy()
|
||||
|
||||
def get_message_count(self) -> int:
|
||||
"""获取消息数量"""
|
||||
return len(self.messages)
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: Message) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def processed_plain_text(self) -> str:
|
||||
"""获取所有消息的文本内容"""
|
||||
return "\n".join(msg.processed_plain_text for msg in self.messages if msg.processed_plain_text)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
322
src/plugins/chat/message_send_control.py
Normal file
322
src/plugins/chat/message_send_control.py
Normal file
@@ -0,0 +1,322 @@
|
||||
from typing import Union, List, Optional, Deque, Dict
|
||||
from nonebot.adapters.onebot.v11 import Bot, MessageSegment
|
||||
import asyncio
|
||||
import random
|
||||
from .message import Message, Message_Thinking, MessageSet
|
||||
from .cq_code import CQCode
|
||||
from collections import deque
|
||||
import time
|
||||
from .storage import MessageStorage # 添加这行导入
|
||||
|
||||
|
||||
class SendTemp:
|
||||
"""单个群组的临时消息队列管理器"""
|
||||
def __init__(self, group_id: int, max_size: int = 100):
|
||||
self.group_id = group_id
|
||||
self.max_size = max_size
|
||||
self.messages: Deque[Union[Message, Message_Thinking]] = deque(maxlen=max_size)
|
||||
self.last_send_time = 0
|
||||
|
||||
def add(self, message: Message) -> None:
|
||||
"""按时间顺序添加消息到队列"""
|
||||
if not self.messages:
|
||||
self.messages.append(message)
|
||||
return
|
||||
|
||||
# 按时间顺序插入
|
||||
if message.time >= self.messages[-1].time:
|
||||
self.messages.append(message)
|
||||
return
|
||||
|
||||
# 使用二分查找找到合适的插入位置
|
||||
messages_list = list(self.messages)
|
||||
left, right = 0, len(messages_list)
|
||||
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if messages_list[mid].time < message.time:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# 重建消息队列,保持时间顺序
|
||||
new_messages = deque(maxlen=self.max_size)
|
||||
new_messages.extend(messages_list[:left])
|
||||
new_messages.append(message)
|
||||
new_messages.extend(messages_list[left:])
|
||||
self.messages = new_messages
|
||||
def get_earliest_message(self) -> Optional[Message]:
|
||||
"""获取时间最早的消息"""
|
||||
message = self.messages.popleft() if self.messages else None
|
||||
# 如果是思考中的消息且思考时间不够,重新加入队列
|
||||
# if (isinstance(message, Message_Thinking) and
|
||||
# time.time() - message.start_time < 2): # 最少思考2秒
|
||||
# self.messages.appendleft(message)
|
||||
# return None
|
||||
return message
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清空队列"""
|
||||
self.messages.clear()
|
||||
|
||||
def get_all(self, group_id: Optional[int] = None) -> List[Union[Message, Message_Thinking]]:
|
||||
"""获取所有待发送的消息"""
|
||||
if group_id is None:
|
||||
return list(self.messages)
|
||||
return [msg for msg in self.messages if msg.group_id == group_id]
|
||||
|
||||
def peek_next(self) -> Optional[Union[Message, Message_Thinking]]:
|
||||
"""查看下一条要发送的消息(不移除)"""
|
||||
return self.messages[0] if self.messages else None
|
||||
|
||||
def has_messages(self) -> bool:
|
||||
"""检查是否有待发送的消息"""
|
||||
return bool(self.messages)
|
||||
|
||||
def count(self, group_id: Optional[int] = None) -> int:
|
||||
"""获取待发送消息数量"""
|
||||
if group_id is None:
|
||||
return len(self.messages)
|
||||
return len([msg for msg in self.messages if msg.group_id == group_id])
|
||||
|
||||
def get_last_send_time(self) -> float:
|
||||
"""获取最后一次发送时间"""
|
||||
return self.last_send_time
|
||||
|
||||
def update_send_time(self):
|
||||
"""更新最后发送时间"""
|
||||
self.last_send_time = time.time()
|
||||
|
||||
class SendTempContainer:
|
||||
"""管理所有群组的消息缓存容器"""
|
||||
def __init__(self):
|
||||
self.temp_queues: Dict[int, SendTemp] = {}
|
||||
|
||||
def get_queue(self, group_id: int) -> SendTemp:
|
||||
"""获取或创建群组的消息队列"""
|
||||
if group_id not in self.temp_queues:
|
||||
self.temp_queues[group_id] = SendTemp(group_id)
|
||||
return self.temp_queues[group_id]
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到对应群组的队列"""
|
||||
queue = self.get_queue(message.group_id)
|
||||
queue.add(message)
|
||||
|
||||
def get_group_messages(self, group_id: int) -> List[Union[Message, Message_Thinking]]:
|
||||
"""获取指定群组的所有待发送消息"""
|
||||
queue = self.get_queue(group_id)
|
||||
return queue.get_all()
|
||||
|
||||
def has_messages(self, group_id: int) -> bool:
|
||||
"""检查指定群组是否有待发送消息"""
|
||||
queue = self.get_queue(group_id)
|
||||
return queue.has_messages()
|
||||
|
||||
def get_all_groups(self) -> List[int]:
|
||||
"""获取所有有待发送消息的群组ID"""
|
||||
return list(self.temp_queues.keys())
|
||||
|
||||
def update_thinking_message(self, message_obj: Union[Message, MessageSet]) -> bool:
|
||||
"""更新思考中的消息
|
||||
|
||||
Args:
|
||||
message_obj: 要更新的消息对象,可以是单条消息或消息组
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
"""
|
||||
queue = self.get_queue(message_obj.group_id)
|
||||
|
||||
# 使用列表解析找到匹配的消息索引
|
||||
matching_indices = [
|
||||
i for i, msg in enumerate(queue.messages)
|
||||
if msg.message_id == message_obj.message_id
|
||||
]
|
||||
|
||||
if not matching_indices:
|
||||
return False
|
||||
|
||||
index = matching_indices[0] # 获取第一个匹配的索引
|
||||
|
||||
# 将消息转换为列表以便修改
|
||||
messages = list(queue.messages)
|
||||
|
||||
# 根据消息类型处理
|
||||
if isinstance(message_obj, MessageSet):
|
||||
messages.pop(index)
|
||||
# 在原位置插入新消息组
|
||||
for i, single_message in enumerate(message_obj.messages):
|
||||
messages.insert(index + i, single_message)
|
||||
# print(f"\033[1;34m[调试]\033[0m 添加消息组中的第{i+1}条消息: {single_message}")
|
||||
else:
|
||||
# 直接替换原消息
|
||||
messages[index] = message_obj
|
||||
# print(f"\033[1;34m[调试]\033[0m 已更新消息: {message_obj}")
|
||||
|
||||
# 重建队列
|
||||
queue.messages.clear()
|
||||
for msg in messages:
|
||||
queue.messages.append(msg)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class MessageSendControl:
|
||||
"""消息发送控制器"""
|
||||
def __init__(self):
|
||||
self.typing_speed = (0.1, 0.3) # 每个字符的打字时间范围(秒)
|
||||
self.message_interval = (0.5, 1) # 多条消息间的间隔时间范围(秒)
|
||||
self.max_retry = 3 # 最大重试次数
|
||||
self.send_temp_container = SendTempContainer()
|
||||
self._running = True
|
||||
self._paused = False
|
||||
self._current_bot = None
|
||||
self.storage = MessageStorage() # 添加存储实例
|
||||
|
||||
def set_bot(self, bot: Bot):
|
||||
"""设置当前bot实例"""
|
||||
self._current_bot = bot
|
||||
|
||||
async def start_processor(self, bot: Bot):
|
||||
"""启动消息处理器"""
|
||||
self._current_bot = bot
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(0.5)
|
||||
# 处理所有群组的消息队列
|
||||
for group_id in self.send_temp_container.get_all_groups():
|
||||
queue = self.send_temp_container.get_queue(group_id)
|
||||
if queue.has_messages():
|
||||
message = queue.peek_next()
|
||||
# print(f"\033[1;34m[调试]\033[0m 查看最早的消息: {message}")
|
||||
if message:
|
||||
if isinstance(message, Message_Thinking):
|
||||
# 如果是思考中的消息,检查是否需要继续等待
|
||||
# message.update_thinking_time()
|
||||
thinking_time = time.time() - message.time
|
||||
if thinking_time < 60: # 最少思考2秒
|
||||
if int(thinking_time) % 10 == 0:
|
||||
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{thinking_time:.1f}秒")
|
||||
continue
|
||||
else:
|
||||
print(f"\033[1;34m[调试]\033[0m 思考消息超时,移除")
|
||||
queue.get_earliest_message() # 移除超时的思考消息
|
||||
|
||||
elif isinstance(message, Message):
|
||||
message = queue.get_earliest_message()
|
||||
if message and message.processed_plain_text:
|
||||
print(f"- 群组: {group_id} - 内容: {message.processed_plain_text}")
|
||||
|
||||
cost_time = round(time.time(), 2) - message.time
|
||||
# print(f"\033[1;34m[调试]\033[0m 消息发送111111时间: {cost_time}秒")
|
||||
if cost_time > 40:
|
||||
message.processed_plain_text = CQCode.create_reply_cq(message.message_based_id) + message.processed_plain_text
|
||||
|
||||
|
||||
|
||||
|
||||
await self._current_bot.send_group_msg(
|
||||
group_id=group_id,
|
||||
message=str(message.processed_plain_text),
|
||||
auto_escape=False
|
||||
)
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
print(f"\033[1;32m群 {group_id} 消息, 用户 麦麦, 时间: {current_time}:\033[0m {str(message.processed_plain_text)}")
|
||||
await self.storage.store_message(message, None)
|
||||
|
||||
queue.update_send_time()
|
||||
|
||||
if queue.has_messages():
|
||||
await asyncio.sleep(
|
||||
random.uniform(
|
||||
self.message_interval[0],
|
||||
self.message_interval[1]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def process_group_queue(self, bot: Bot, group_id: int) -> None:
|
||||
"""处理指定群组的消息队列"""
|
||||
queue = self.send_temp_container.get_queue(group_id)
|
||||
while queue.has_messages():
|
||||
message = queue.get_earliest_message()
|
||||
if message and message.processed_plain_text:
|
||||
await self.send_message(
|
||||
bot=bot,
|
||||
group_id=group_id,
|
||||
content=message.processed_plain_text
|
||||
)
|
||||
queue.update_send_time()
|
||||
|
||||
if queue.has_messages():
|
||||
await asyncio.sleep(
|
||||
random.uniform(self.message_interval[0], self.message_interval[1])
|
||||
)
|
||||
|
||||
async def process_all_queues(self, bot: Bot) -> None:
|
||||
"""处理所有群组的消息队列"""
|
||||
if not self._running or self._paused:
|
||||
return
|
||||
|
||||
for group_id in self.send_temp_container.get_all_groups():
|
||||
await self.process_group_queue(bot, group_id)
|
||||
|
||||
async def send_temp_message(self,
|
||||
bot: Bot,
|
||||
group_id: int,
|
||||
message: Union[Message, Message_Thinking],
|
||||
with_emoji: bool = False,
|
||||
emoji_path: Optional[str] = None) -> bool:
|
||||
"""
|
||||
发送单个临时消息
|
||||
Args:
|
||||
bot: Bot实例
|
||||
group_id: 群组ID
|
||||
message: Message对象
|
||||
with_emoji: 是否带表情
|
||||
emoji_path: 表情图片路径
|
||||
Returns:
|
||||
bool: 发送是否成功
|
||||
"""
|
||||
try:
|
||||
if with_emoji and emoji_path:
|
||||
return await self.send_with_emoji(
|
||||
bot=bot,
|
||||
group_id=group_id,
|
||||
text_content=message.processed_plain_text,
|
||||
emoji_path=emoji_path
|
||||
)
|
||||
else:
|
||||
return await self.send_message(
|
||||
bot=bot,
|
||||
group_id=group_id,
|
||||
content=message.processed_plain_text
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 发送临时消息失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def set_typing_speed(self, min_speed: float, max_speed: float):
|
||||
"""设置打字速度范围"""
|
||||
self.typing_speed = (min_speed, max_speed)
|
||||
|
||||
def set_message_interval(self, min_interval: float, max_interval: float):
|
||||
"""设置消息间隔范围"""
|
||||
self.message_interval = (min_interval, max_interval)
|
||||
|
||||
def pause(self):
|
||||
"""暂停消息处理"""
|
||||
self._paused = True
|
||||
|
||||
def resume(self):
|
||||
"""恢复消息处理"""
|
||||
self._paused = False
|
||||
|
||||
def stop(self):
|
||||
"""停止消息处理"""
|
||||
self._running = False
|
||||
|
||||
# 创建全局实例
|
||||
message_sender = MessageSendControl()
|
||||
264
src/plugins/chat/message_stream.py
Normal file
264
src/plugins/chat/message_stream.py
Normal file
@@ -0,0 +1,264 @@
|
||||
from typing import List, Optional, Dict
|
||||
from .message import Message
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
class MessageStream:
|
||||
"""单个群组的消息流容器"""
|
||||
def __init__(self, group_id: int, max_size: int = 1000):
|
||||
self.group_id = group_id
|
||||
self.messages = deque(maxlen=max_size)
|
||||
self.max_size = max_size
|
||||
self.last_save_time = time.time()
|
||||
|
||||
# 确保日志目录存在
|
||||
self.log_dir = os.path.join("log", str(self.group_id))
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
|
||||
# 启动自动保存任务
|
||||
asyncio.create_task(self._auto_save())
|
||||
|
||||
async def _auto_save(self):
|
||||
"""每30秒自动保存一次消息记录"""
|
||||
while True:
|
||||
await asyncio.sleep(30) # 等待30秒
|
||||
await self.save_to_log()
|
||||
|
||||
async def save_to_log(self):
|
||||
"""将消息保存到日志文件"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
# 只有有新消息时才保存
|
||||
if not self.messages or self.last_save_time == current_time:
|
||||
return
|
||||
|
||||
# 生成日志文件名 (使用当前日期)
|
||||
date_str = time.strftime("%Y-%m-%d", time.localtime(current_time))
|
||||
log_file = os.path.join(self.log_dir, f"chat_{date_str}.log")
|
||||
|
||||
# 获取需要保存的新消息
|
||||
new_messages = [
|
||||
msg for msg in self.messages
|
||||
if msg.time > self.last_save_time
|
||||
]
|
||||
|
||||
if not new_messages:
|
||||
return
|
||||
|
||||
# 将消息转换为可序列化的格式
|
||||
message_logs = []
|
||||
for msg in new_messages:
|
||||
message_logs.append({
|
||||
"time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(msg.time)),
|
||||
"user_id": msg.user_id,
|
||||
"user_nickname": msg.user_nickname,
|
||||
"message_id": msg.message_id,
|
||||
"raw_message": msg.raw_message,
|
||||
"processed_text": msg.processed_plain_text
|
||||
})
|
||||
|
||||
# 追加写入日志文件
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
for log in message_logs:
|
||||
f.write(json.dumps(log, ensure_ascii=False) + "\n")
|
||||
|
||||
self.last_save_time = current_time
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 保存群 {self.group_id} 的消息日志失败: {str(e)}")
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""按时间顺序添加新消息到队列
|
||||
|
||||
使用改进的二分查找算法来保持消息的时间顺序,同时优化内存使用。
|
||||
|
||||
Args:
|
||||
message: Message对象,要添加的新消息
|
||||
"""
|
||||
|
||||
# 空队列或消息应该添加到末尾的情况
|
||||
if (not self.messages or
|
||||
message.time >= self.messages[-1].time):
|
||||
self.messages.append(message)
|
||||
return
|
||||
|
||||
# 消息应该添加到开头的情况
|
||||
if message.time <= self.messages[0].time:
|
||||
self.messages.appendleft(message)
|
||||
return
|
||||
|
||||
# 使用二分查找在现有队列中找到合适的插入位置
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left <= right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].time < message.time:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid - 1
|
||||
|
||||
temp = list(self.messages)
|
||||
temp.insert(left, message)
|
||||
|
||||
# 如果超出最大长度,移除多余的消息
|
||||
if len(temp) > self.max_size:
|
||||
temp = temp[-self.max_size:]
|
||||
|
||||
# 重建队列
|
||||
self.messages = deque(temp, maxlen=self.max_size)
|
||||
|
||||
async def get_recent_messages_from_db(self, count: int = 10) -> List[Message]:
|
||||
"""从数据库中获取最近的消息记录
|
||||
|
||||
Args:
|
||||
count: 需要获取的消息数量
|
||||
|
||||
Returns:
|
||||
List[Message]: 最近的消息列表
|
||||
"""
|
||||
try:
|
||||
from ...common.database import Database
|
||||
db = Database.get_instance()
|
||||
|
||||
# 从数据库中查询最近的消息
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": self.group_id},
|
||||
{
|
||||
"time": 1,
|
||||
"user_id": 1,
|
||||
"user_nickname": 1,
|
||||
"message_id": 1,
|
||||
"raw_message": 1,
|
||||
"processed_text": 1
|
||||
}
|
||||
).sort("time", -1).limit(count))
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
# 转换为 Message 对象
|
||||
from .message import Message
|
||||
messages = []
|
||||
for msg_data in recent_messages:
|
||||
msg = Message(
|
||||
time=msg_data["time"],
|
||||
user_id=msg_data["user_id"],
|
||||
user_nickname=msg_data.get("user_nickname", ""),
|
||||
message_id=msg_data["message_id"],
|
||||
raw_message=msg_data["raw_message"],
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
group_id=self.group_id
|
||||
)
|
||||
messages.append(msg)
|
||||
|
||||
return list(reversed(messages)) # 返回按时间正序的消息
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 从数据库获取群 {self.group_id} 的最近消息记录失败: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_recent_messages(self, count: int = 10) -> List[Message]:
|
||||
"""获取最近的n条消息(从内存队列)"""
|
||||
print(f"\033[1;34m[调试]\033[0m 从内存获取群 {self.group_id} 的最近{count}条消息记录")
|
||||
return list(self.messages)[-count:]
|
||||
|
||||
def get_messages_in_timerange(self,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None) -> List[Message]:
|
||||
"""获取时间范围内的消息"""
|
||||
if start_time is None:
|
||||
start_time = time.time() - 3600
|
||||
if end_time is None:
|
||||
end_time = time.time()
|
||||
|
||||
return [
|
||||
msg for msg in self.messages
|
||||
if start_time <= msg.time <= end_time
|
||||
]
|
||||
|
||||
def get_user_messages(self, user_id: int, count: int = 10) -> List[Message]:
|
||||
"""获取特定用户的最近消息"""
|
||||
user_messages = [msg for msg in self.messages if msg.user_id == user_id]
|
||||
return user_messages[-count:]
|
||||
|
||||
def clear_old_messages(self, hours: int = 24) -> None:
|
||||
"""清理旧消息"""
|
||||
cutoff_time = time.time() - (hours * 3600)
|
||||
self.messages = deque(
|
||||
[msg for msg in self.messages if msg.time > cutoff_time],
|
||||
maxlen=self.max_size
|
||||
)
|
||||
|
||||
class MessageStreamContainer:
|
||||
"""管理所有群组的消息流容器"""
|
||||
def __init__(self, max_size: int = 1000):
|
||||
self.streams: Dict[int, MessageStream] = {}
|
||||
self.max_size = max_size
|
||||
|
||||
async def save_all_logs(self):
|
||||
"""保存所有群组的消息日志"""
|
||||
for stream in self.streams.values():
|
||||
await stream.save_to_log()
|
||||
|
||||
def add_message(self, message: Message) -> None:
|
||||
"""添加消息到对应群组的消息流"""
|
||||
if not message.group_id:
|
||||
return
|
||||
|
||||
if message.group_id not in self.streams:
|
||||
self.streams[message.group_id] = MessageStream(message.group_id, self.max_size)
|
||||
|
||||
self.streams[message.group_id].add_message(message)
|
||||
|
||||
def get_stream(self, group_id: int) -> Optional[MessageStream]:
|
||||
"""获取特定群组的消息流"""
|
||||
return self.streams.get(group_id)
|
||||
|
||||
def get_all_streams(self) -> Dict[int, MessageStream]:
|
||||
"""获取所有群组的消息流"""
|
||||
return self.streams
|
||||
|
||||
def clear_old_messages(self, hours: int = 24) -> None:
|
||||
"""清理所有群组的旧消息"""
|
||||
for stream in self.streams.values():
|
||||
stream.clear_old_messages(hours)
|
||||
|
||||
def get_group_stats(self, group_id: int) -> Dict:
|
||||
"""获取群组的消息统计信息"""
|
||||
stream = self.streams.get(group_id)
|
||||
if not stream:
|
||||
return {
|
||||
"total_messages": 0,
|
||||
"unique_users": 0,
|
||||
"active_hours": [],
|
||||
"most_active_user": None
|
||||
}
|
||||
|
||||
messages = stream.messages
|
||||
user_counts = {}
|
||||
hour_counts = {}
|
||||
|
||||
for msg in messages:
|
||||
user_counts[msg.user_id] = user_counts.get(msg.user_id, 0) + 1
|
||||
hour = datetime.fromtimestamp(msg.time).hour
|
||||
hour_counts[hour] = hour_counts.get(hour, 0) + 1
|
||||
|
||||
most_active_user = max(user_counts.items(), key=lambda x: x[1])[0] if user_counts else None
|
||||
active_hours = sorted(
|
||||
hour_counts.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)[:5]
|
||||
|
||||
return {
|
||||
"total_messages": len(messages),
|
||||
"unique_users": len(user_counts),
|
||||
"active_hours": active_hours,
|
||||
"most_active_user": most_active_user
|
||||
}
|
||||
|
||||
# 创建全局实例
|
||||
message_stream_container = MessageStreamContainer()
|
||||
138
src/plugins/chat/message_visualizer.py
Normal file
138
src/plugins/chat/message_visualizer.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import subprocess
|
||||
import threading
|
||||
import queue
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
from .message import Message_Thinking
|
||||
|
||||
class MessageVisualizer:
|
||||
def __init__(self):
|
||||
self.process = None
|
||||
self.message_queue = queue.Queue()
|
||||
self.is_running = False
|
||||
self.content_file = "message_queue_content.txt"
|
||||
|
||||
def start(self):
|
||||
if self.process is None:
|
||||
# 创建用于显示的批处理文件
|
||||
with open("message_queue_window.bat", "w", encoding="utf-8") as f:
|
||||
f.write('@echo off\n')
|
||||
f.write('chcp 65001\n') # 设置UTF-8编码
|
||||
f.write('title Message Queue Visualizer\n')
|
||||
f.write('echo Waiting for message queue updates...\n')
|
||||
f.write(':loop\n')
|
||||
f.write('if exist "queue_update.txt" (\n')
|
||||
f.write(' type "queue_update.txt" > "message_queue_content.txt"\n')
|
||||
f.write(' del "queue_update.txt"\n')
|
||||
f.write(' cls\n')
|
||||
f.write(' type "message_queue_content.txt"\n')
|
||||
f.write(')\n')
|
||||
f.write('timeout /t 1 /nobreak >nul\n')
|
||||
f.write('goto loop\n')
|
||||
|
||||
# 清空内容文件
|
||||
with open(self.content_file, "w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
|
||||
# 启动新窗口
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
self.process = subprocess.Popen(
|
||||
['cmd', '/c', 'start', 'message_queue_window.bat'],
|
||||
shell=True,
|
||||
startupinfo=startupinfo
|
||||
)
|
||||
self.is_running = True
|
||||
|
||||
# 启动处理线程
|
||||
threading.Thread(target=self._process_messages, daemon=True).start()
|
||||
|
||||
def _process_messages(self):
|
||||
while self.is_running:
|
||||
try:
|
||||
# 获取新消息
|
||||
text = self.message_queue.get(timeout=1)
|
||||
# 写入更新文件
|
||||
with open("queue_update.txt", "w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
except queue.Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"处理队列可视化内容时出错: {e}")
|
||||
|
||||
def update_content(self, send_temp_container):
|
||||
"""更新显示内容"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
display_text = f"Message Queue Status - {current_time}\n"
|
||||
display_text += "=" * 50 + "\n\n"
|
||||
|
||||
# 遍历所有群组的队列
|
||||
for group_id, queue in send_temp_container.temp_queues.items():
|
||||
display_text += f"\n{'='*20} 群组: {queue.group_id} {'='*20}\n"
|
||||
display_text += f"消息队列长度: {len(queue.messages)}\n"
|
||||
display_text += f"最后发送时间: {time.strftime('%H:%M:%S', time.localtime(queue.last_send_time))}\n"
|
||||
display_text += "\n消息队列内容:\n"
|
||||
|
||||
# 显示队列中的消息
|
||||
if not queue.messages:
|
||||
display_text += " [空队列]\n"
|
||||
else:
|
||||
for i, msg in enumerate(queue.messages):
|
||||
msg_time = time.strftime("%H:%M:%S", time.localtime(msg.time))
|
||||
display_text += f"\n--- 消息 {i+1} ---\n"
|
||||
|
||||
if isinstance(msg, Message_Thinking):
|
||||
display_text += f"类型: \033[1;33m思考中消息\033[0m\n"
|
||||
display_text += f"时间: {msg_time}\n"
|
||||
display_text += f"消息ID: {msg.message_id}\n"
|
||||
display_text += f"群组: {msg.group_id}\n"
|
||||
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
|
||||
display_text += f"内容: {msg.thinking_text}\n"
|
||||
display_text += f"思考时间: {msg.thinking_time}秒\n"
|
||||
else:
|
||||
display_text += f"类型: 普通消息\n"
|
||||
display_text += f"时间: {msg_time}\n"
|
||||
display_text += f"消息ID: {msg.message_id}\n"
|
||||
display_text += f"群组: {msg.group_id}\n"
|
||||
display_text += f"用户: {msg.user_nickname}({msg.user_id})\n"
|
||||
if hasattr(msg, 'is_emoji') and msg.is_emoji:
|
||||
display_text += f"内容: [表情包消息]\n"
|
||||
else:
|
||||
# 显示原始消息和处理后的消息
|
||||
display_text += f"原始内容: {msg.raw_message[:50]}...\n"
|
||||
display_text += f"处理后内容: {msg.processed_plain_text[:50]}...\n"
|
||||
|
||||
if msg.reply_message:
|
||||
display_text += f"回复消息: {str(msg.reply_message)[:50]}...\n"
|
||||
|
||||
display_text += f"\n{'-' * 50}\n"
|
||||
|
||||
# 添加统计信息
|
||||
display_text += "\n总体统计:\n"
|
||||
display_text += f"活跃群组数: {len(send_temp_container.temp_queues)}\n"
|
||||
total_messages = sum(len(q.messages) for q in send_temp_container.temp_queues.values())
|
||||
display_text += f"总消息数: {total_messages}\n"
|
||||
thinking_messages = sum(
|
||||
sum(1 for msg in q.messages if isinstance(msg, Message_Thinking))
|
||||
for q in send_temp_container.temp_queues.values()
|
||||
)
|
||||
display_text += f"思考中消息数: {thinking_messages}\n"
|
||||
|
||||
self.message_queue.put(display_text)
|
||||
|
||||
def stop(self):
|
||||
self.is_running = False
|
||||
if self.process:
|
||||
self.process.terminate()
|
||||
self.process = None
|
||||
# 清理文件
|
||||
for file in ["message_queue_window.bat", "message_queue_content.txt", "queue_update.txt"]:
|
||||
if os.path.exists(file):
|
||||
os.remove(file)
|
||||
|
||||
# 创建全局单例
|
||||
message_visualizer = MessageVisualizer()
|
||||
193
src/plugins/chat/prompt_builder.py
Normal file
193
src/plugins/chat/prompt_builder.py
Normal file
@@ -0,0 +1,193 @@
|
||||
import time
|
||||
import random
|
||||
from dotenv import load_dotenv
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
import os
|
||||
from .utils import get_embedding, combine_messages, get_recent_group_messages
|
||||
from ...common.database import Database
|
||||
|
||||
# 获取当前文件的绝对路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
load_dotenv(os.path.join(root_dir, '.env'))
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ''
|
||||
self.activate_messages = ''
|
||||
self.db = Database.get_instance()
|
||||
|
||||
def _build_prompt(self,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
relationship_value: float = 0.0,
|
||||
group_id: int = None) -> str:
|
||||
"""构建prompt
|
||||
|
||||
Args:
|
||||
message_txt: 消息文本
|
||||
sender_name: 发送者昵称
|
||||
relationship_value: 关系值
|
||||
group_id: 群组ID
|
||||
|
||||
Returns:
|
||||
str: 构建好的prompt
|
||||
"""
|
||||
#先禁用关系
|
||||
if 0 > 30:
|
||||
relation_prompt = "关系特别特别好,你很喜欢喜欢他"
|
||||
relation_prompt_2 = "热情发言或者回复"
|
||||
elif 0 <-20:
|
||||
relation_prompt = "关系很差,你很讨厌他"
|
||||
relation_prompt_2 = "骂他"
|
||||
else:
|
||||
relation_prompt = "关系一般"
|
||||
relation_prompt_2 = "发言或者回复"
|
||||
|
||||
#开始构建prompt
|
||||
|
||||
#日程构建
|
||||
current_date = time.strftime("%Y-%m-%d", time.localtime())
|
||||
current_time = time.strftime("%H:%M:%S", time.localtime())
|
||||
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
|
||||
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n'''
|
||||
|
||||
#知识构建
|
||||
prompt_info = ''
|
||||
prompt_info = self.get_prompt_info(message_txt)
|
||||
if prompt_info:
|
||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:
|
||||
\n{prompt_info}\n
|
||||
请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
||||
|
||||
#获取最新消息并计时
|
||||
chat_talking_prompt = ''
|
||||
|
||||
if group_id:
|
||||
# 从数据库获取最近消息
|
||||
message_objects = get_recent_group_messages(self.db, group_id)
|
||||
if message_objects:
|
||||
chat_talking_prompt = combine_messages(message_objects)
|
||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的最近{len(message_objects)}条消息记录")
|
||||
|
||||
#激活prompt构建
|
||||
activate_prompt = ''
|
||||
activate_prompt = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2} \n"
|
||||
|
||||
#人格选择
|
||||
prompt_personality = ''
|
||||
personality_choice = random.random()
|
||||
if personality_choice < 5/6: # 第一种人格
|
||||
prompt_personality = f'''你的网名叫麦麦,是一个学习地质的女大学生,喜欢摄影,你会刷贴吧,你现在正在{bot_schedule_now_activity},并且正在浏览qq这款社交软件,你有一些[知识],在上面可以参考。
|
||||
现在请你给出日常且口语化的回复,不要过多提及自身的背景,平淡一些,尽量简短一些,不要输出感叹号。如果有人说你是人机或者机器人,否定并攻击那个人。
|
||||
不要回复的太有逻辑或者太有条理。'''
|
||||
elif personality_choice < 1: # 第二种人格
|
||||
prompt_personality = f'''你的网名叫麦麦,是一个女大学生,你有黑色头发你会刷微博和小红书,你现在正在{bot_schedule_now_activity},并且正在浏览qq群,你有一些[知识],在上面可以参考。
|
||||
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些,不要输出感叹号。如果有人说你是人机或者机器人,否定。
|
||||
如果有人说他是机器人,肯定他。如果有人对你谈论涉及性的话题或言语挑逗,需要反驳,但其它情况请你表达自己的见解和观点。'''
|
||||
|
||||
#额外信息要求
|
||||
extra_info = '''回复平淡一些,简短一些,但是记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
|
||||
|
||||
#合并prompt
|
||||
prompt = ""
|
||||
prompt += f"{prompt_info}\n"
|
||||
prompt += f"{prompt_date}\n"
|
||||
prompt += f"{chat_talking_prompt}\n"
|
||||
prompt += f"{activate_prompt}\n"
|
||||
prompt += f"{prompt_personality}\n"
|
||||
prompt += f"{extra_info}\n"
|
||||
|
||||
return prompt
|
||||
|
||||
def get_prompt_info(self,message:str):
|
||||
related_info = ''
|
||||
if len(message) > 10:
|
||||
message_segments = [message[i:i+10] for i in range(0, len(message), 10)]
|
||||
for segment in message_segments:
|
||||
embedding = get_embedding(segment)
|
||||
related_info += self.get_info_from_db(embedding)
|
||||
|
||||
else:
|
||||
embedding = get_embedding(message)
|
||||
related_info += self.get_info_from_db(embedding)
|
||||
|
||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
||||
"""
|
||||
从知识库中查找与输入向量最相似的内容
|
||||
Args:
|
||||
query_embedding: 查询向量
|
||||
limit: 返回结果数量,默认为2
|
||||
threshold: 相似度阈值,默认为0.5
|
||||
Returns:
|
||||
str: 找到的相关信息,如果相似度低于阈值则返回空字符串
|
||||
"""
|
||||
if not query_embedding:
|
||||
return ''
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]}
|
||||
]}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"similarity": {
|
||||
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$match": {
|
||||
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1}}
|
||||
]
|
||||
|
||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||
|
||||
if not results:
|
||||
return ''
|
||||
|
||||
# 返回所有找到的内容,用换行分隔
|
||||
return '\n'.join(str(result['content']) for result in results)
|
||||
|
||||
prompt_builder = PromptBuilder()
|
||||
200
src/plugins/chat/relationship_manager.py
Normal file
200
src/plugins/chat/relationship_manager.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import time
|
||||
from ...common.database import Database
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
from typing import Optional, Tuple
|
||||
import asyncio
|
||||
|
||||
class Impression:
|
||||
traits: str = None
|
||||
called: str = None
|
||||
know_time: float = None
|
||||
|
||||
relationship_value: float = None
|
||||
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
# impression: Impression = None
|
||||
# group_id: int = None
|
||||
# group_name: str = None
|
||||
gender: str = None
|
||||
age: int = None
|
||||
nickname: str = None
|
||||
relationship_value: float = None
|
||||
saved = False
|
||||
|
||||
def __init__(self, user_id: int, data=None, **kwargs):
|
||||
if isinstance(data, dict):
|
||||
# 如果输入是字典,使用字典解析
|
||||
self.user_id = data.get('user_id')
|
||||
self.gender = data.get('gender')
|
||||
self.age = data.get('age')
|
||||
self.nickname = data.get('nickname')
|
||||
self.relationship_value = data.get('relationship_value', 0.0)
|
||||
self.saved = data.get('saved', False)
|
||||
else:
|
||||
# 如果是直接传入属性值
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.gender = kwargs.get('gender')
|
||||
self.age = kwargs.get('age')
|
||||
self.nickname = kwargs.get('nickname')
|
||||
self.relationship_value = kwargs.get('relationship_value', 0.0)
|
||||
self.saved = kwargs.get('saved', False)
|
||||
|
||||
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationships: dict[int, Relationship] = {} # user_id -> Relationship
|
||||
#保存 qq号,现在使用昵称,别称
|
||||
self.id_name_nickname_table: dict[str, str, list] = {} # name -> [nickname, nickname, ...]
|
||||
|
||||
async def update_relationship(self, user_id: int, data=None, **kwargs):
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
if relationship:
|
||||
# 如果存在,更新现有对象
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
else:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
else:
|
||||
# 如果不存在,创建新对象
|
||||
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, **kwargs)
|
||||
self.relationships[user_id] = relationship
|
||||
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self, user_id: int, **kwargs):
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
if relationship:
|
||||
for key, value in kwargs.items():
|
||||
if key == 'relationship_value':
|
||||
relationship.relationship_value += value
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
return relationship
|
||||
else:
|
||||
print(f"\033[1;31m[关系管理]\033[0m 用户 {user_id} 不存在,无法更新")
|
||||
return None
|
||||
|
||||
|
||||
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
||||
"""获取用户关系对象"""
|
||||
if user_id in self.relationships:
|
||||
return self.relationships[user_id]
|
||||
else:
|
||||
return 0
|
||||
|
||||
async def load_relationship(self, data: dict) -> Relationship:
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
rela = Relationship(user_id=data['user_id'], data=data)
|
||||
rela.saved = True
|
||||
return rela
|
||||
|
||||
async def _start_relationship_manager(self):
|
||||
"""每5分钟自动保存一次关系数据"""
|
||||
db = Database.get_instance()
|
||||
# 获取所有关系记录
|
||||
all_relationships = db.db.relationships.find({})
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
user_id = data['user_id']
|
||||
relationship = await self.load_relationship(data)
|
||||
self.relationships[user_id] = relationship
|
||||
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
print(f"\033[1;32m[关系管理]\033[0m 正在自动保存关系")
|
||||
await asyncio.sleep(300) # 等待300秒(5分钟)
|
||||
await self._save_all_relationships()
|
||||
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for relationship in self.relationships:
|
||||
if not relationship.saved:
|
||||
relationship.saved = True
|
||||
await self.storage_relationship(relationship)
|
||||
|
||||
async def storage_relationship(self,relationship: Relationship):
|
||||
"""
|
||||
将关系记录存储到数据库中
|
||||
"""
|
||||
user_id = relationship.user_id
|
||||
nickname = relationship.nickname
|
||||
relationship_value = relationship.relationship_value
|
||||
gender = relationship.gender
|
||||
age = relationship.age
|
||||
saved = relationship.saved
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id},
|
||||
{'$set': {
|
||||
'nickname': nickname,
|
||||
'relationship_value': relationship_value,
|
||||
'gender': gender,
|
||||
'age': age,
|
||||
'saved': saved
|
||||
}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_nickname(bot: Bot, user_id: int, group_id: int = None) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
通过QQ API获取用户昵称
|
||||
"""
|
||||
|
||||
# 获取QQ昵称
|
||||
stranger_info = await bot.get_stranger_info(user_id=user_id)
|
||||
qq_nickname = stranger_info['nickname']
|
||||
|
||||
# 如果提供了群号,获取群昵称
|
||||
if group_id:
|
||||
try:
|
||||
member_info = await bot.get_group_member_info(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
no_cache=True
|
||||
)
|
||||
group_nickname = member_info['card'] or None
|
||||
return qq_nickname, group_nickname
|
||||
except:
|
||||
return qq_nickname, None
|
||||
|
||||
return qq_nickname, None
|
||||
|
||||
def print_all_relationships(self):
|
||||
"""打印内存中所有的关系记录"""
|
||||
print("\n\033[1;32m[关系管理]\033[0m 当前内存中的所有关系:")
|
||||
print("=" * 50)
|
||||
|
||||
if not self.relationships:
|
||||
print("暂无关系记录")
|
||||
return
|
||||
|
||||
for user_id, relationship in self.relationships.items():
|
||||
print(f"用户ID: {user_id}")
|
||||
print(f"昵称: {relationship.nickname}")
|
||||
print(f"好感度: {relationship.relationship_value}")
|
||||
print("-" * 30)
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
relationship_manager = RelationshipManager()
|
||||
48
src/plugins/chat/storage.py
Normal file
48
src/plugins/chat/storage.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import Dict, List, Any, Optional
|
||||
import time
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
import asyncio
|
||||
from .message import Message
|
||||
from ...common.database import Database
|
||||
from .image_utils import storage_compress_image
|
||||
|
||||
class MessageStorage:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
|
||||
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
if not message.is_emoji:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
}
|
||||
else:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
"processed_plain_text": '[表情包]',
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
}
|
||||
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 存储消息失败: {e}")
|
||||
|
||||
# 如果需要其他存储相关的函数,可以在这里添加
|
||||
96
src/plugins/chat/topic_identifier.py
Normal file
96
src/plugins/chat/topic_identifier.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Optional, Dict, List
|
||||
from openai import OpenAI
|
||||
from .message import Message
|
||||
from .config import global_config, llm_config
|
||||
import jieba
|
||||
|
||||
class TopicIdentifier:
|
||||
def __init__(self):
|
||||
self.client = OpenAI(
|
||||
api_key=llm_config.SILICONFLOW_API_KEY,
|
||||
base_url=llm_config.SILICONFLOW_BASE_URL
|
||||
)
|
||||
|
||||
def identify_topic_llm(self, text: str) -> Optional[str]:
|
||||
"""识别消息主题"""
|
||||
|
||||
prompt = f"""判断这条消息的主题,如果没有明显主题请回复"无主题",要求:
|
||||
1. 主题通常2-4个字,必须简短,要求精准概括,不要太具体。
|
||||
2. 建议给出多个主题,之间用英文逗号分割。只输出主题本身就好,不要有前后缀。
|
||||
|
||||
消息内容:{text}"""
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model="Pro/deepseek-ai/DeepSeek-V3",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.8,
|
||||
max_tokens=10
|
||||
)
|
||||
|
||||
if not response or not response.choices:
|
||||
print(f"\033[1;31m[错误]\033[0m OpenAI API 返回为空")
|
||||
return None
|
||||
|
||||
# 从 OpenAI API 响应中获取第一个选项的消息内容,并去除首尾空白字符
|
||||
topic = response.choices[0].message.content.strip() if response.choices[0].message.content else None
|
||||
|
||||
if topic == "无主题":
|
||||
return None
|
||||
else:
|
||||
# print(f"[主题分析结果]{text[:20]}... : {topic}")
|
||||
split_topic = self.parse_topic(topic)
|
||||
return split_topic
|
||||
|
||||
|
||||
def parse_topic(self, topic: str) -> List[str]:
|
||||
"""解析主题,返回主题列表"""
|
||||
if not topic or topic == "无主题":
|
||||
return []
|
||||
return [t.strip() for t in topic.split(",") if t.strip()]
|
||||
|
||||
def identify_topic_jieba(self, text: str) -> Optional[str]:
|
||||
"""使用jieba识别主题"""
|
||||
words = jieba.lcut(text)
|
||||
# 去除停用词和标点符号
|
||||
stop_words = {
|
||||
'的', '了', '和', '是', '就', '都', '而', '及', '与', '这', '那', '但', '然', '却',
|
||||
'因为', '所以', '如果', '虽然', '一个', '我', '你', '他', '她', '它', '我们', '你们',
|
||||
'他们', '在', '有', '个', '把', '被', '让', '给', '从', '向', '到', '又', '也', '很',
|
||||
'啊', '吧', '呢', '吗', '呀', '哦', '哈', '么', '嘛', '啦', '哎', '唉', '哇', '嗯',
|
||||
'哼', '哪', '什么', '怎么', '为什么', '怎样', '如何', '什么样', '这样', '那样', '这么',
|
||||
'那么', '多少', '几', '谁', '哪里', '哪儿', '什么时候', '何时', '为何', '怎么办',
|
||||
'怎么样', '这些', '那些', '一些', '一点', '一下', '一直', '一定', '一般', '一样',
|
||||
'一会儿', '一边', '一起',
|
||||
# 添加更多量词
|
||||
'个', '只', '条', '张', '片', '块', '本', '册', '页', '幅', '面', '篇', '份',
|
||||
'朵', '颗', '粒', '座', '幢', '栋', '间', '层', '家', '户', '位', '名', '群',
|
||||
'双', '对', '打', '副', '套', '批', '组', '串', '包', '箱', '袋', '瓶', '罐',
|
||||
# 添加更多介词
|
||||
'按', '按照', '把', '被', '比', '比如', '除', '除了', '当', '对', '对于',
|
||||
'根据', '关于', '跟', '和', '将', '经', '经过', '靠', '连', '论', '通过',
|
||||
'同', '往', '为', '为了', '围绕', '于', '由', '由于', '与', '在', '沿', '沿着',
|
||||
'依', '依照', '以', '因', '因为', '用', '由', '与', '自', '自从'
|
||||
}
|
||||
|
||||
# 过滤掉停用词和标点符号,只保留名词和动词
|
||||
filtered_words = []
|
||||
for word in words:
|
||||
if word not in stop_words and not word.strip() in {
|
||||
'。', ',', '、', ':', ';', '!', '?', '"', '"', ''', ''',
|
||||
'(', ')', '【', '】', '《', '》', '…', '—', '·', '、', '~',
|
||||
'~', '+', '=', '-'
|
||||
}:
|
||||
filtered_words.append(word)
|
||||
|
||||
# 统计词频
|
||||
word_freq = {}
|
||||
for word in filtered_words:
|
||||
word_freq[word] = word_freq.get(word, 0) + 1
|
||||
|
||||
# 按词频排序,取前3个
|
||||
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
top_words = [word for word, freq in sorted_words[:3]]
|
||||
|
||||
return top_words if top_words else None
|
||||
|
||||
topic_identifier = TopicIdentifier()
|
||||
115
src/plugins/chat/utils.py
Normal file
115
src/plugins/chat/utils.py
Normal file
@@ -0,0 +1,115 @@
|
||||
import time
|
||||
from typing import List
|
||||
from .message import Message
|
||||
import requests
|
||||
import numpy as np
|
||||
from .config import llm_config
|
||||
|
||||
def combine_messages(messages: List[Message]) -> str:
|
||||
"""将消息列表组合成格式化的字符串
|
||||
|
||||
Args:
|
||||
messages: Message对象列表
|
||||
|
||||
Returns:
|
||||
str: 格式化后的消息字符串
|
||||
"""
|
||||
result = ""
|
||||
for message in messages:
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
name = message.user_nickname or f"用户{message.user_id}"
|
||||
content = message.processed_plain_text or message.plain_text
|
||||
|
||||
result += f"[{time_str}] {name}: {content}\n"
|
||||
|
||||
return result
|
||||
|
||||
def is_mentioned_bot_in_message(message: Message) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = ['麦麦', '麦哲伦']
|
||||
for keyword in keywords:
|
||||
if keyword in message.processed_plain_text:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
keywords = ['麦麦', '麦哲伦']
|
||||
for keyword in keywords:
|
||||
if keyword in message:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_embedding(text):
|
||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"API请求失败: {response.status_code}")
|
||||
print(f"错误信息: {response.text}")
|
||||
return None
|
||||
|
||||
return response.json()['data'][0]['embedding']
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
return dot_product / (norm1 * norm2)
|
||||
|
||||
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
"""从数据库获取群组最近的消息记录
|
||||
|
||||
Args:
|
||||
db: Database实例
|
||||
group_id: 群组ID
|
||||
limit: 获取消息数量,默认12条
|
||||
|
||||
Returns:
|
||||
list: Message对象列表,按时间正序排列
|
||||
"""
|
||||
|
||||
# 从数据库获取最近消息
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": group_id},
|
||||
{
|
||||
"time": 1,
|
||||
"user_id": 1,
|
||||
"user_nickname": 1,
|
||||
"message_id": 1,
|
||||
"raw_message": 1,
|
||||
"processed_text": 1
|
||||
}
|
||||
).sort("time", -1).limit(limit))
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
# 转换为 Message对象列表
|
||||
from .message import Message
|
||||
message_objects = []
|
||||
for msg_data in recent_messages:
|
||||
msg = Message(
|
||||
time=msg_data["time"],
|
||||
user_id=msg_data["user_id"],
|
||||
user_nickname=msg_data.get("user_nickname", ""),
|
||||
message_id=msg_data["message_id"],
|
||||
raw_message=msg_data["raw_message"],
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
group_id=group_id
|
||||
)
|
||||
message_objects.append(msg)
|
||||
|
||||
# 按时间正序排列
|
||||
message_objects.reverse()
|
||||
return message_objects
|
||||
77
src/plugins/chat/willing_manager.py
Normal file
77
src/plugins/chat/willing_manager.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import asyncio
|
||||
|
||||
class WillingManager:
|
||||
def __init__(self):
|
||||
self.group_reply_willing = {} # 存储每个群的回复意愿
|
||||
self._decay_task = None
|
||||
self._started = False
|
||||
|
||||
async def _decay_reply_willing(self):
|
||||
"""定期衰减回复意愿"""
|
||||
while True:
|
||||
await asyncio.sleep(3)
|
||||
for group_id in self.group_reply_willing:
|
||||
# 每分钟衰减10%的回复意愿
|
||||
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
||||
|
||||
def get_willing(self, group_id: int) -> float:
|
||||
"""获取指定群组的回复意愿"""
|
||||
return self.group_reply_willing.get(group_id, 0)
|
||||
|
||||
def set_willing(self, group_id: int, willing: float):
|
||||
"""设置指定群组的回复意愿"""
|
||||
self.group_reply_willing[group_id] = willing
|
||||
|
||||
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False) -> float:
|
||||
"""改变指定群组的回复意愿并返回回复概率"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
|
||||
if topic and current_willing < 1:
|
||||
current_willing += 0.6
|
||||
elif topic:
|
||||
current_willing += 0.05
|
||||
|
||||
if is_mentioned_bot and current_willing < 1.0:
|
||||
current_willing += 1
|
||||
elif is_mentioned_bot:
|
||||
current_willing += 0.05
|
||||
|
||||
if is_emoji:
|
||||
current_willing *= 0.2
|
||||
|
||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||
|
||||
reply_probability = (current_willing - 0.5) * 2
|
||||
if group_id not in config.talk_allowed_groups:
|
||||
current_willing = 0
|
||||
reply_probability = 0
|
||||
|
||||
if group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / 2
|
||||
|
||||
if is_mentioned_bot and user_id == int(1026294844):
|
||||
reply_probability = 1
|
||||
|
||||
return reply_probability
|
||||
|
||||
def change_reply_willing_sent(self, group_id: int):
|
||||
"""发送消息后降低群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
self.group_reply_willing[group_id] = max(0, current_willing - 1.8)
|
||||
|
||||
def change_reply_willing_after_sent(self, group_id: int):
|
||||
"""发送消息后提高群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
# 如果当前意愿小于1,增加0.3的意愿值
|
||||
if current_willing < 1:
|
||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.8)
|
||||
|
||||
async def ensure_started(self):
|
||||
"""确保衰减任务已启动"""
|
||||
if not self._started:
|
||||
if self._decay_task is None:
|
||||
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
||||
self._started = True
|
||||
|
||||
# 创建全局实例
|
||||
willing_manager = WillingManager()
|
||||
156
src/plugins/schedule/schedule_generator.py
Normal file
156
src/plugins/schedule/schedule_generator.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import datetime
|
||||
from typing import List, Dict
|
||||
from .schedule_llm_module import LLMModel
|
||||
from ...common.database import Database # 使用正确的导入语法
|
||||
|
||||
|
||||
# import sys
|
||||
# sys.path.append("C:/GitHub/MegMeg-bot") # 添加项目根目录到 Python 路径
|
||||
# from src.plugins.schedule.schedule_llm_module import LLMModel
|
||||
# from src.common.database import Database # 使用正确的导入语法
|
||||
|
||||
Database.initialize(
|
||||
"127.0.0.1",
|
||||
27017,
|
||||
"MegBot"
|
||||
)
|
||||
|
||||
class ScheduleGenerator:
|
||||
def __init__(self):
|
||||
self.llm_scheduler = LLMModel(model_name="Pro/deepseek-ai/DeepSeek-V3")
|
||||
self.db = Database.get_instance()
|
||||
|
||||
today = datetime.datetime.now()
|
||||
tomorrow = datetime.datetime.now() + datetime.timedelta(days=1)
|
||||
yesterday = datetime.datetime.now() - datetime.timedelta(days=1)
|
||||
|
||||
self.today_schedule_text, self.today_schedule = self.generate_daily_schedule(target_date=today)
|
||||
|
||||
self.tomorrow_schedule_text, self.tomorrow_schedule = self.generate_daily_schedule(target_date=tomorrow,read_only=True)
|
||||
self.yesterday_schedule_text, self.yesterday_schedule = self.generate_daily_schedule(target_date=yesterday,read_only=True)
|
||||
|
||||
def generate_daily_schedule(self, target_date: datetime.datetime = None,read_only:bool = False) -> Dict[str, str]:
|
||||
if target_date is None:
|
||||
target_date = datetime.datetime.now()
|
||||
|
||||
date_str = target_date.strftime("%Y-%m-%d")
|
||||
weekday = target_date.strftime("%A")
|
||||
|
||||
|
||||
schedule_text = str
|
||||
|
||||
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
|
||||
if existing_schedule:
|
||||
print(f"{date_str}的日程已存在:")
|
||||
schedule_text = existing_schedule["schedule"]
|
||||
# print(self.schedule_text)
|
||||
|
||||
elif read_only == False:
|
||||
print(f"{date_str}的日程不存在,准备生成新的日程。")
|
||||
prompt = f"""我是麦麦,一个地质学大二女大学生,喜欢刷qq,贴吧,知乎和小红书,请为我生成{date_str}({weekday})的日程安排,包括:
|
||||
1. 早上的学习和工作安排
|
||||
2. 下午的活动和任务
|
||||
3. 晚上的计划和休息时间
|
||||
请按照时间顺序列出具体时间点和对应的活动,用一个时间点而不是时间段来表示时间,用逗号,隔开时间与活动,格式为"时间,活动",例如"08:00,起床"。"""
|
||||
|
||||
schedule_text, _ = self.llm_scheduler.generate_response(prompt)
|
||||
# print(self.schedule_text)
|
||||
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
||||
else:
|
||||
print(f"{date_str}的日程不存在。")
|
||||
schedule_text = "忘了"
|
||||
|
||||
return schedule_text,None
|
||||
|
||||
schedule_form = self._parse_schedule(schedule_text)
|
||||
return schedule_text,schedule_form
|
||||
|
||||
def _parse_schedule(self, schedule_text: str) -> Dict[str, str]:
|
||||
"""解析日程文本,转换为时间和活动的字典"""
|
||||
schedule_dict = {}
|
||||
# 按行分割日程文本
|
||||
lines = schedule_text.strip().split('\n')
|
||||
for line in lines:
|
||||
# print(line)
|
||||
if ',' in line:
|
||||
# 假设格式为 "时间: 活动"
|
||||
time_str, activity = line.split(',', 1)
|
||||
# print(time_str)
|
||||
# print(activity)
|
||||
schedule_dict[time_str.strip()] = activity.strip()
|
||||
return schedule_dict
|
||||
|
||||
def _parse_time(self, time_str: str) -> str:
|
||||
"""解析时间字符串,转换为时间"""
|
||||
return datetime.datetime.strptime(time_str, "%H:%M")
|
||||
|
||||
def get_current_task(self) -> str:
|
||||
"""获取当前时间应该进行的任务"""
|
||||
current_time = datetime.datetime.now().strftime("%H:%M")
|
||||
|
||||
# 找到最接近当前时间的任务
|
||||
closest_time = None
|
||||
min_diff = float('inf')
|
||||
|
||||
# 检查今天的日程
|
||||
for time_str in self.today_schedule.keys():
|
||||
diff = abs(self._time_diff(current_time, time_str))
|
||||
if closest_time is None or diff < min_diff:
|
||||
closest_time = time_str
|
||||
min_diff = diff
|
||||
|
||||
# 检查昨天的日程中的晚间任务
|
||||
if self.yesterday_schedule:
|
||||
for time_str in self.yesterday_schedule.keys():
|
||||
if time_str >= "20:00": # 只考虑晚上8点之后的任务
|
||||
# 计算与昨天这个时间点的差异(需要加24小时)
|
||||
diff = abs(self._time_diff(current_time, time_str))
|
||||
if diff < min_diff:
|
||||
closest_time = time_str
|
||||
min_diff = diff
|
||||
return closest_time, self.yesterday_schedule[closest_time]
|
||||
|
||||
if closest_time:
|
||||
return closest_time, self.today_schedule[closest_time]
|
||||
return "摸鱼"
|
||||
|
||||
def _time_diff(self, time1: str, time2: str) -> int:
|
||||
"""计算两个时间字符串之间的分钟差"""
|
||||
t1 = datetime.datetime.strptime(time1, "%H:%M")
|
||||
t2 = datetime.datetime.strptime(time2, "%H:%M")
|
||||
diff = int((t2 - t1).total_seconds() / 60)
|
||||
# 考虑时间的循环性
|
||||
if diff < -720:
|
||||
diff += 1440 # 加一天的分钟
|
||||
elif diff > 720:
|
||||
diff -= 1440 # 减一天的分钟
|
||||
# print(f"时间1[{time1}]: 时间2[{time2}],差值[{diff}]分钟")
|
||||
return diff
|
||||
|
||||
def print_schedule(self):
|
||||
"""打印完整的日程安排"""
|
||||
|
||||
print("\n=== 今日日程安排 ===")
|
||||
for time_str, activity in self.today_schedule.items():
|
||||
print(f"时间[{time_str}]: 活动[{activity}]")
|
||||
print("==================\n")
|
||||
|
||||
# def main():
|
||||
# # 使用示例
|
||||
# scheduler = ScheduleGenerator()
|
||||
# # new_schedule = scheduler.generate_daily_schedule()
|
||||
# scheduler.print_schedule()
|
||||
# print("\n当前任务:")
|
||||
# print(scheduler.get_current_task())
|
||||
|
||||
# print("昨天日程:")
|
||||
# print(scheduler.yesterday_schedule)
|
||||
# print("今天日程:")
|
||||
# print(scheduler.today_schedule)
|
||||
# print("明天日程:")
|
||||
# print(scheduler.tomorrow_schedule)
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
bot_schedule = ScheduleGenerator()
|
||||
55
src/plugins/schedule/schedule_llm_module.py
Normal file
55
src/plugins/schedule/schedule_llm_module.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from typing import Tuple, Union
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class LLMModel:
|
||||
# def __init__(self, model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", **kwargs):
|
||||
def __init__(self, model_name="Pro/deepseek-ai/DeepSeek-R1", **kwargs):
|
||||
self.model_name = model_name
|
||||
self.params = kwargs
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
self.base_url = os.getenv("SILICONFLOW_BASE_URL")
|
||||
|
||||
def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||
"""根据输入的提示生成模型的响应"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.9,
|
||||
**self.params
|
||||
}
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
|
||||
try:
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response.raise_for_status() # 检查响应状态
|
||||
|
||||
result = response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content # 返回内容和推理内容
|
||||
return "没有返回结果", "" # 返回两个值
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return f"请求失败: {str(e)}", "" # 返回错误信息和空字符串
|
||||
|
||||
# 示例用法
|
||||
if __name__ == "__main__":
|
||||
model = LLMModel() # 默认使用 DeepSeek-V3 模型
|
||||
prompt = "你好,你喜欢我吗?"
|
||||
result, reasoning = model.generate_response(prompt)
|
||||
print("回复内容:", result)
|
||||
print("推理内容:", reasoning)
|
||||
Reference in New Issue
Block a user