Merge branch 'debug' into main

This commit is contained in:
SengokuCola
2025-03-12 00:20:54 +08:00
committed by GitHub
64 changed files with 6255 additions and 2846 deletions

View File

@@ -1,12 +1,10 @@
import asyncio
import os
import random
import time
import os
from loguru import logger
from nonebot import get_driver, on_command, on_message, require
from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment
from nonebot.rule import to_me
from nonebot.typing import T_State
from ...common.database import Database
@@ -18,6 +16,11 @@ from .config import global_config
from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager
from .willing_manager import willing_manager
from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
from .message_sender import message_manager, message_sender
# 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt")
@@ -30,27 +33,21 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
print("\033[1;32m[初始化数据库完成]\033[0m")
logger.success("初始化数据库成功")
# 导入其他模块
from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot
# from .message_send_control import message_sender
from .message_sender import message_manager, message_sender
# 初始化表情管理器
emoji_manager.initialize()
print(f"\033[1;32m正在唤醒{global_config.BOT_NICKNAME}......\033[0m")
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 创建机器人实例
chat_bot = ChatBot()
# 注册群消息处理器
@@ -59,69 +56,80 @@ group_msg = on_message(priority=5)
scheduler = require("nonebot_plugin_apscheduler").scheduler
@driver.on_startup
async def start_background_tasks():
"""启动后台任务"""
# 启动LLM统计
llm_stats.start()
print("\033[1;32m[初始化]\033[0m LLM统计功能启动")
logger.success("LLM统计功能启动成功")
# 初始化并启动情绪管理器
mood_manager = MoodManager.get_instance()
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
print("\033[1;32m[初始化]\033[0m 情绪管理器启动")
logger.success("情绪管理器启动成功")
# 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
await bot_schedule.initialize()
bot_schedule.print_schedule()
@driver.on_startup
async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器"""
print("\033[1;32m[初始化]\033[0m 正在加载用户关系数据...")
logger.debug("正在加载用户关系数据...")
await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect
async def _(bot: Bot):
"""Bot连接成功时的处理"""
global _message_manager_started
print(f"\033[1;38;5;208m-----------{global_config.BOT_NICKNAME}成功连接!-----------\033[0m")
logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
await willing_manager.ensure_started()
message_sender.set_bot(bot)
print("\033[1;38;5;208m-----------消息发送器已启动!-----------\033[0m")
logger.success("-----------消息发送器已启动!-----------")
if not _message_manager_started:
asyncio.create_task(message_manager.start_processor())
_message_manager_started = True
print("\033[1;38;5;208m-----------消息处理器已启动!-----------\033[0m")
logger.success("-----------消息处理器已启动!-----------")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
logger.success("-----------开始偷表情包!-----------")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
await chat_bot.handle_message(event, bot)
# 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建"""
print("\033[1;32m[记忆构建]\033[0m -------------------------------------------开始构建记忆-------------------------------------------")
logger.debug(
"[记忆构建]"
"------------------------------------开始构建记忆--------------------------------------")
start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time()
print(f"\033[1;32m[记忆构建]\033[0m -------------------------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
logger.success(
f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
"秒-------------------------------------------")
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""每30秒执行一次记忆构建"""
# print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
# await hippocampus.operation_forget_topic(percentage=0.1)
# print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task():
@@ -130,9 +138,9 @@ async def merge_memory_task():
# await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
@scheduler.scheduled_job("interval", seconds=30, id="print_mood")
async def print_mood_task():
"""每30秒打印一次情绪状态"""
mood_manager = MoodManager.get_instance()
mood_manager.print_mood_status()

View File

@@ -1,26 +1,27 @@
import re
import time
from random import random
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
from .cq_code import CQCode # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import (
Message,
Message_Sending,
Message_Thinking, # 导入 Message_Thinking 类
MessageSet,
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import (
MessageRecvCQ,
)
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager
from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot:
@@ -31,10 +32,10 @@ class ChatBot:
self._started = False
self.mood_manager = MoodManager.get_instance() # 获取情绪管理器单例
self.mood_manager.start_mood_update() # 启动情绪更新
self.emoji_chance = 0.2 # 发送表情包的基础概率
# self.message_streams = MessageStreamContainer()
async def _ensure_started(self):
"""确保所有任务已启动"""
if not self._started:
@@ -42,185 +43,232 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息"""
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例
try:
group_info_api = await bot.get_group_info(group_id=event.group_id)
logger.info(f"成功获取群信息: {group_info_api}")
group_name = group_info_api["group_name"]
except Exception as e:
logger.error(f"获取群信息失败: {str(e)}")
group_name = None
# 白名单设定由nontbot侧完成
# 消息过滤涉及到config有待更新
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
if event.user_id in global_config.ban_user_id:
return
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id = event.user_id, data = sender_info)
await relationship_manager.update_relationship_value(user_id = event.user_id, relationship_value = 0.5)
message = Message(
group_id=event.group_id,
user_info = UserInfo(
user_id=event.user_id,
message_id=event.message_id,
user_cardname=sender_info['card'],
raw_message=str(event.original_message),
plain_text=event.get_plaintext(),
reply_message=event.reply,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
await message.initialize()
group_info = GroupInfo(
group_id=event.group_id,
group_name=group_name, # 使用获取到的群名称或None
platform="qq",
)
message_cq = MessageRecvCQ(
message_id=event.message_id,
user_info=user_info,
raw_message=str(event.original_message),
group_info=group_info,
reply_message=event.reply,
platform="qq",
)
message_json = message_cq.to_dict()
# 进入maimbot
message = MessageRecv(message_json)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
# 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo
)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(
chat_stream=chat,
)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=0.5)
await message.process()
# 过滤词
for word in global_config.ban_words:
if word in message.detailed_plain_text:
logger.info(f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered")
if word in message.processed_plain_text:
logger.info(f"[群{groupinfo.group_id}]{userinfo.user_nickname}:{message.processed_plain_text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
# 正则表达式过滤
for pattern in global_config.ban_msgs_regex:
if re.search(pattern, message.raw_message):
logger.info(
f"[群{message.message_info.group_info.group_id}]{message.user_nickname}:{message.raw_message}"
)
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = ''
topic = ""
interested_rate = 0
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text)/100
print(f"\033[1;32m[记忆激活]\033[0m {message.processed_plain_text}的激活度:---------------------------------------{interested_rate}\n")
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
logger.debug(f"{message.processed_plain_text}的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
reply_probability = willing_manager.change_reply_willing_received(
event.group_id,
topic[0] if topic else None,
is_mentioned,
global_config,
event.user_id,
message.is_emoji,
interested_rate
await self.storage.store_message(message, chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_message(message)
reply_probability = await willing_manager.change_reply_willing_received(
chat_stream=chat,
topic=topic[0] if topic else None,
is_mentioned_bot=is_mentioned,
config=global_config,
is_emoji=message.is_emoji,
interested_rate=interested_rate,
)
current_willing = willing_manager.get_willing(event.group_id)
print(f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]\033[0m")
current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info(
f"[{current_time}][{chat.group_info.group_id}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
)
response = None
response = ""
if random() < reply_probability:
tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message,message_id=think_id)
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, platform=messageinfo.platform
)
thinking_time_point = round(time.time(), 2)
think_id = "mt" + str(thinking_time_point)
thinking_message = MessageThinking(
message_id=think_id, chat_stream=chat, bot_user_info=bot_user_info, reply=message
)
message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id)
response,raw_content = await self.gpt.generate_response(message)
willing_manager.change_reply_willing_sent(chat)
response, raw_content = await self.gpt.generate_response(message)
# print(f"response: {response}")
if response:
container = message_manager.get_container(event.group_id)
# print(f"有response: {response}")
container = message_manager.get_container(chat.stream_id)
thinking_message = None
# 找到message,删除
# print(f"开始找思考消息")
for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
# print(f"找到思考消息: {msg}")
thinking_message = msg
container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break
# 如果找不到思考消息,直接返回
if not thinking_message:
print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除")
logger.warning("未找到对应的思考消息,可能已超时被移除")
return
#记录开始思考的时间,避免从思考到回复的时间太久
# 记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
#计算打字时间1是为了模拟打字2是避免多条回复乱序
message_set = MessageSet(chat, think_id)
# 计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False
for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
#通过时间改变时间戳
# 通过时间改变时间戳
typing_time = calculate_typing_time(msg)
print(f"typing_time: {typing_time}")
accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
timepoint = thinking_time_point + accu_typing_time
message_segment = Seg(type="text", data=msg)
print(f"message_segment: {message_segment}")
bot_message = MessageSending(
message_id=think_id,
raw_message=msg,
plain_text=msg,
processed_plain_text=msg,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=timepoint, #记录了回复生成的时间
thinking_start_time=thinking_start_time, #记录了思考开始的时间
reply_message_id=message.message_id
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=not mark_head,
is_emoji=False,
)
await bot_message.initialize()
print(f"bot_message: {bot_message}")
if not mark_head:
bot_message.is_head = True
mark_head = True
print(f"添加消息到message_set: {bot_message}")
message_set.add_message(bot_message)
#message_set 可以直接加入 message_manager
# message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
print("添加message_set到message_manager")
message_manager.add_message(message_set)
bot_response_time = tinking_time_point
bot_response_time = thinking_time_point
if random() < global_config.emoji_chance:
emoji_raw = await emoji_manager.get_emoji_for_text(response)
# 检查是否 <没有找到> emoji
if emoji_raw != None:
emoji_path,discription = emoji_raw
emoji_path, description = emoji_raw
emoji_cq = image_path_to_base64(emoji_path)
emoji_cq = CQCode.create_emoji_cq(emoji_path)
if random() < 0.5:
bot_response_time = tinking_time_point - 1
bot_response_time = thinking_time_point - 1
else:
bot_response_time = bot_response_time + 1
bot_message = Message_Sending(
group_id=event.group_id,
user_id=global_config.BOT_QQ,
message_id=0,
raw_message=emoji_cq,
plain_text=emoji_cq,
processed_plain_text=emoji_cq,
detailed_plain_text=discription,
user_nickname=global_config.BOT_NICKNAME,
group_name=message.group_name,
time=bot_response_time,
message_segment = Seg(type="emoji", data=emoji_cq)
bot_message = MessageSending(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=message,
is_head=False,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
)
await bot_message.initialize()
message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content)
print(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict={
'happy': 0.5,
'angry': -1,
'sad': -0.5,
'surprised': 0.2,
'disgusted': -1.5,
'fearful': -0.7,
'neutral': 0.1
logger.debug(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict = {
"happy": 0.5,
"angry": -1,
"sad": -0.5,
"surprised": 0.2,
"disgusted": -1.5,
"fearful": -0.7,
"neutral": 0.1,
}
await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
await relationship_manager.update_relationship_value(
chat_stream=chat, relationship_value=valuedict[emotion[0]]
)
# 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id)
# willing_manager.change_reply_willing_after_sent(
# chat_stream=chat
# )
# 创建全局ChatBot实例
chat_bot = ChatBot()
chat_bot = ChatBot()

View File

@@ -0,0 +1,226 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import GroupInfo, UserInfo
class ChatStream:
    """A chat stream: the stored context of one conversation.

    Attributes:
        stream_id: Unique identifier of the stream.
        platform: Platform identifier the stream belongs to.
        user_info: User information attached to the stream (may be ``None``).
        group_info: Group information attached to the stream, if any.
        create_time: Creation time as integer epoch seconds.
        last_active_time: Time of the last recorded activity, epoch seconds.
        saved: Dirty flag; ``False`` means the stream has changes that have
            not been written out yet.
    """

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: "UserInfo",
        group_info: Optional["GroupInfo"] = None,
        data: Optional[dict] = None,
    ):
        """Build a stream, optionally restoring timestamps from ``data``.

        Args:
            stream_id: Unique identifier of the stream.
            platform: Platform identifier.
            user_info: User information for the stream.
            group_info: Group information, or ``None``.
            data: Optional mapping that may carry ``create_time`` and
                ``last_active_time``; missing keys fall back to "now" and to
                ``create_time`` respectively.
        """
        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info

        # Treat a missing or empty payload as "no stored values" once, rather
        # than re-testing it for every field.
        stored = data or {}
        self.create_time = stored.get("create_time", int(time.time()))
        self.last_active_time = stored.get("last_active_time", self.create_time)

        # A newly constructed object always starts out as "not yet saved".
        self.saved = False

    def to_dict(self) -> dict:
        """Return a plain-dict representation of the stream."""
        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": self.user_info.to_dict() if self.user_info else None,
            "group_info": self.group_info.to_dict() if self.group_info else None,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ChatStream":
        """Create a stream from a dict shaped like the output of ``to_dict``.

        Args:
            data: Mapping with at least ``stream_id`` and ``platform``.
                ``user_info`` / ``group_info`` are rebuilt only when present
                and non-empty; otherwise they become ``None``.

        Raises:
            KeyError: If ``stream_id`` or ``platform`` is missing.
        """
        user_payload = data.get("user_info")
        group_payload = data.get("group_info")
        user_info = UserInfo(**user_payload) if user_payload else None
        group_info = GroupInfo(**group_payload) if group_payload else None

        return cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=user_info,
            group_info=group_info,
            data=data,
        )

    def update_active_time(self):
        """Record activity now and mark the stream as needing a save."""
        self.last_active_time = int(time.time())
        self.saved = False
class ChatManager:
    """Singleton registry of chat streams backed by the ``chat_streams`` collection.

    Streams are cached in ``self.streams`` (``stream_id`` -> ``ChatStream``)
    and written to the database through ``_save_stream``.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Always hand back the same object so every caller shares one cache.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Run the synchronous setup once; later constructions are no-ops."""
        if not self._initialized:
            self.streams: Dict[str, "ChatStream"] = {}  # stream_id -> ChatStream
            self.db = Database.get_instance()
            self._ensure_collection()
            self._initialized = True
            # Only synchronous setup happens here: _initialize() and
            # _auto_save_task() are coroutines and are not started by the
            # constructor.

    async def _initialize(self):
        """Load all stored streams into memory, logging success or failure."""
        try:
            await self.load_all_streams()
            logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
        except Exception as e:
            logger.error(f"聊天管理器启动失败: {str(e)}")

    async def _auto_save_task(self):
        """Save every stream in an endless loop, once every five minutes."""
        while True:
            await asyncio.sleep(300)  # five minutes between save passes
            try:
                await self._save_all_streams()
                logger.info("聊天流自动保存完成")
            except Exception as e:
                # Best effort: a failed pass is logged and the loop continues.
                logger.error(f"聊天流自动保存失败: {str(e)}")

    def _ensure_collection(self):
        """Create the ``chat_streams`` collection and its indexes if missing."""
        if "chat_streams" not in self.db.db.list_collection_names():
            self.db.db.create_collection("chat_streams")
            # Unique lookup by stream_id, plus a compound index on the
            # platform / user / group triple.
            self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
            self.db.db.chat_streams.create_index(
                [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
            )

    def _generate_stream_id(
        self,
        platform: str,
        user_info: "UserInfo",
        group_info: Optional["GroupInfo"] = None,
    ) -> str:
        """Derive the stream ID for a conversation.

        With ``group_info`` the ID depends on platform and group ID only;
        without it, on platform, user ID and the literal ``"private"``.

        Returns:
            Hex MD5 digest of the underscore-joined components.
        """
        if group_info:
            components = [platform, str(group_info.group_id)]
        else:
            components = [platform, str(user_info.user_id), "private"]

        # MD5 is used here only to build a compact, stable key.
        key = "_".join(components)
        return hashlib.md5(key.encode()).hexdigest()

    async def get_or_create_stream(
        self,
        platform: str,
        user_info: "UserInfo",
        group_info: Optional["GroupInfo"] = None,
    ) -> "ChatStream":
        """Return the stream for this conversation, creating it if needed.

        Lookup order: in-memory cache, then database, then a brand-new stream.
        The returned object is always a deep copy of the cached stream, so
        mutating it does not affect the cache.

        Args:
            platform: Platform identifier.
            user_info: User information for the current message.
            group_info: Group information (optional).

        Returns:
            ChatStream: A copy of the stream carrying the given user/group info.
        """
        stream_id = self._generate_stream_id(platform, user_info, group_info)

        # 1) Already cached: bump the cached stream's activity time, then hand
        #    out a copy carrying the caller's user/group info.
        # NOTE(review): on this path the cached stream keeps its previous
        # user_info/group_info; only the returned copy is updated, whereas the
        # database path below updates the cached object itself. Confirm this
        # asymmetry is intended.
        if stream_id in self.streams:
            stream = self.streams[stream_id]
            stream.update_active_time()
            stream = copy.deepcopy(stream)
            stream.user_info = user_info
            if group_info:
                stream.group_info = group_info
            return stream

        # 2) Not cached: look in the database, otherwise create a new stream.
        data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
        if data:
            stream = ChatStream.from_dict(data)
            stream.user_info = user_info
            if group_info:
                stream.group_info = group_info
            stream.update_active_time()
        else:
            stream = ChatStream(
                stream_id=stream_id,
                platform=platform,
                user_info=user_info,
                group_info=group_info,
            )

        # Cache it and persist immediately.
        self.streams[stream_id] = stream
        await self._save_stream(stream)
        return copy.deepcopy(stream)

    def get_stream(self, stream_id: str) -> Optional["ChatStream"]:
        """Return the cached stream with this ID, or ``None`` (no DB lookup)."""
        return self.streams.get(stream_id)

    def get_stream_by_info(
        self,
        platform: str,
        user_info: "UserInfo",
        group_info: Optional["GroupInfo"] = None,
    ) -> Optional["ChatStream"]:
        """Return the cached stream for this conversation, or ``None``."""
        stream_id = self._generate_stream_id(platform, user_info, group_info)
        return self.streams.get(stream_id)

    async def _save_stream(self, stream: "ChatStream"):
        """Upsert ``stream`` into the database unless it is already saved."""
        if not stream.saved:
            self.db.db.chat_streams.update_one(
                {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
            )
            stream.saved = True

    async def _save_all_streams(self):
        """Save every cached stream that has unsaved changes."""
        for stream in self.streams.values():
            await self._save_stream(stream)

    async def load_all_streams(self):
        """Load every stored stream from the database into the cache.

        Existing cache entries with the same ``stream_id`` are replaced.
        """
        for data in self.db.db.chat_streams.find({}):
            stream = ChatStream.from_dict(data)
            # What was just read is by definition already persisted; without
            # this, the next save pass would rewrite every document unchanged.
            stream.saved = True
            self.streams[stream.stream_id] = stream
# Create the module-level singleton instance.
chat_manager = ChatManager()

View File

@@ -1,50 +1,55 @@
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Dict, List, Optional
import tomli
from loguru import logger
from packaging import version
from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet,InvalidSpecifier
from packaging.specifiers import SpecifierSet, InvalidSpecifier
@dataclass
class BotConfig:
"""机器人配置类"""
"""机器人配置类"""
INNER_VERSION: Version = None
BOT_QQ: Optional[int] = 1
BOT_NICKNAME: Optional[str] = None
BOT_ALIAS_NAMES: List[str] = field(default_factory=list) # 别名,可以通过这个叫它
# 消息处理相关配置
MIN_TEXT_LENGTH: int = 2 # 最小处理文本长度
MAX_CONTEXT_SIZE: int = 15 # 上下文最大消息数
emoji_chance: float = 0.2 # 发送表情包的基础概率
ENABLE_PIC_TRANSLATE: bool = True # 是否启用图片翻译
talk_allowed_groups = set()
talk_frequency_down_groups = set()
thinking_timeout: int = 100 # 思考时间
response_willing_amplifier: float = 1.0 # 回复意愿放大系数
response_interested_rate_amplifier: float = 1.0 # 回复兴趣度放大系数
down_frequency_rate: float = 3.5 # 降低回复频率的群组回复意愿降低系数
ban_user_id = set()
build_memory_interval: int = 30 # 记忆构建间隔(秒)
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
EMOJI_SAVE: bool = True # 偷表情包
EMOJI_CHECK: bool = False #是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
EMOJI_CHECK: bool = False # 是否开启过滤
EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
ban_words = set()
ban_msgs_regex = set()
max_response_length: int = 1024 # 最大回复长度
# 模型配置
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
@@ -60,83 +65,86 @@ class BotConfig:
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
enable_advance_output: bool = False # 是否启用高级输出
enable_kuuki_read: bool = True # 是否启用读空气功能
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
enable_kuuki_read: bool = True # 是否启用读空气功能
enable_debug_output: bool = False # 是否启用调试输出
keywords_reaction_rules = [] # 关键词回复规则
mood_update_interval: float = 1.0 # 情绪更新间隔 单位秒
mood_decay_rate: float = 0.95 # 情绪衰减率
mood_intensity_factor: float = 0.7 # 情绪强度因子
chinese_typo_enable=True # 是否启用中文错别字生成器
chinese_typo_error_rate=0.03 # 单字替换概率
chinese_typo_min_freq=7 # 最小字频阈值
chinese_typo_tone_error_rate=0.2 # 声调错误概率
chinese_typo_word_replace_rate=0.02 # 整词替换概率
keywords_reaction_rules = [] # 关键词回复规则
chinese_typo_enable = True # 是否启用中文错别字生成器
chinese_typo_error_rate = 0.03 # 单字替换概率
chinese_typo_min_freq = 7 # 最小字频阈值
chinese_typo_tone_error_rate = 0.2 # 声调错误概率
chinese_typo_word_replace_rate = 0.02 # 整词替换概率
# 默认人设
PROMPT_PERSONALITY=[
PROMPT_PERSONALITY = [
"曾经是一个学习地质的女大学生,现在学习心理学和脑科学,你会刷贴吧",
"是一个女大学生,你有黑色头发,你会刷小红书",
"是一个女大学生你会刷b站对ACG文化感兴趣"
"是一个女大学生你会刷b站对ACG文化感兴趣",
]
PROMPT_SCHEDULE_GEN="一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第一种人格概率
PERSONALITY_2: float = 0.3 # 第种人格概率
PERSONALITY_3: float = 0.1 # 第种人格概率
PROMPT_SCHEDULE_GEN = "一个曾经学习地质,现在学习心理学和脑科学的女大学生喜欢刷qq贴吧知乎和小红书"
PERSONALITY_1: float = 0.6 # 第种人格概率
PERSONALITY_2: float = 0.3 # 第种人格概率
PERSONALITY_3: float = 0.1 # 第三种人格概率
memory_ban_words: list = field(
default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
) # 添加新的配置项默认值
@staticmethod
def get_config_dir() -> str:
"""获取配置文件目录"""
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
config_dir = os.path.join(root_dir, 'config')
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
config_dir = os.path.join(root_dir, "config")
if not os.path.exists(config_dir):
os.makedirs(config_dir)
return config_dir
@classmethod
def convert_to_specifierset(cls, value: str) -> SpecifierSet:
"""将 字符串 版本表达式转换成 SpecifierSet
Args:
value[str]: 版本表达式(字符串)
Returns:
SpecifierSet
SpecifierSet
"""
try:
converted = SpecifierSet(value)
except InvalidSpecifier as e:
logger.error(
f"{value} 分类使用了错误的版本约束表达式\n",
"请阅读 https://semver.org/lang/zh-CN/ 修改代码"
)
except InvalidSpecifier:
logger.error(f"{value} 分类使用了错误的版本约束表达式\n", "请阅读 https://semver.org/lang/zh-CN/ 修改代码")
exit(1)
return converted
@classmethod
def get_config_version(cls, toml: dict) -> Version:
"""提取配置文件的 SpecifierSet 版本数据
"""提取配置文件的 SpecifierSet 版本数据
Args:
toml[dict]: 输入的配置文件字典
Returns:
Version
Version
"""
if 'inner' in toml:
if "inner" in toml:
try:
config_version : str = toml["inner"]["version"]
config_version: str = toml["inner"]["version"]
except KeyError as e:
logger.error(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件")
logger.error("配置文件中 inner 段 不存在, 这是错误的配置文件")
raise KeyError(f"配置文件中 inner 段 不存在 {e}, 这是错误的配置文件") from e
else:
toml["inner"] = { "version": "0.0.0" }
toml["inner"] = {"version": "0.0.0"}
config_version = toml["inner"]["version"]
try:
ver = version.parse(config_version)
except InvalidVersion as e:
@@ -145,41 +153,41 @@ class BotConfig:
"请阅读 https://semver.org/lang/zh-CN/ 修改配置,并参考本项目指定的模板进行修改\n"
"本项目在不同的版本下有不同的模板,请注意识别"
)
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n")
raise InvalidVersion("配置文件中 inner段 的 version 键是错误的版本描述\n") from e
return ver
@classmethod
def load_config(cls, config_path: str = None) -> "BotConfig":
"""从TOML配置文件加载配置"""
config = cls()
def personality(parent: dict):
personality_config=parent['personality']
personality=personality_config.get('prompt_personality')
personality_config = parent["personality"]
personality = personality_config.get("prompt_personality")
if len(personality) >= 2:
logger.info(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY=personality_config.get('prompt_personality',config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN=personality_config.get('prompt_schedule',config.PROMPT_SCHEDULE_GEN)
logger.debug(f"载入自定义人格:{personality}")
config.PROMPT_PERSONALITY = personality_config.get("prompt_personality", config.PROMPT_PERSONALITY)
logger.info(f"载入自定义日程prompt:{personality_config.get('prompt_schedule', config.PROMPT_SCHEDULE_GEN)}")
config.PROMPT_SCHEDULE_GEN = personality_config.get("prompt_schedule", config.PROMPT_SCHEDULE_GEN)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.PERSONALITY_1=personality_config.get('personality_1_probability',config.PERSONALITY_1)
config.PERSONALITY_2=personality_config.get('personality_2_probability',config.PERSONALITY_2)
config.PERSONALITY_3=personality_config.get('personality_3_probability',config.PERSONALITY_3)
config.PERSONALITY_1 = personality_config.get("personality_1_probability", config.PERSONALITY_1)
config.PERSONALITY_2 = personality_config.get("personality_2_probability", config.PERSONALITY_2)
config.PERSONALITY_3 = personality_config.get("personality_3_probability", config.PERSONALITY_3)
def emoji(parent: dict):
emoji_config = parent["emoji"]
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK)
config.EMOJI_CHECK_PROMPT = emoji_config.get("check_prompt", config.EMOJI_CHECK_PROMPT)
config.EMOJI_SAVE = emoji_config.get("auto_save", config.EMOJI_SAVE)
config.EMOJI_CHECK = emoji_config.get("enable_check", config.EMOJI_CHECK)
def cq_code(parent: dict):
cq_code_config = parent["cq_code"]
config.ENABLE_PIC_TRANSLATE = cq_code_config.get("enable_pic_translate", config.ENABLE_PIC_TRANSLATE)
def bot(parent: dict):
# 机器人基础配置
bot_config = parent["bot"]
@@ -187,16 +195,21 @@ class BotConfig:
config.BOT_QQ = int(bot_qq)
config.BOT_NICKNAME = bot_config.get("nickname", config.BOT_NICKNAME)
if config.INNER_VERSION in SpecifierSet(">=0.0.5"):
config.BOT_ALIAS_NAMES = bot_config.get("alias_names", config.BOT_ALIAS_NAMES)
def response(parent: dict):
response_config = parent["response"]
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get(
"model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY
)
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
def model(parent: dict):
# 加载模型配置
model_config:dict = parent["model"]
model_config: dict = parent["model"]
config_list = [
"llm_reasoning",
@@ -208,29 +221,23 @@ class BotConfig:
"llm_emotion_judge",
"vlm",
"embedding",
"moderation"
"moderation",
]
for item in config_list:
if item in model_config:
cfg_item:dict = model_config[item]
cfg_item: dict = model_config[item]
# base_url 的例子: SILICONFLOW_BASE_URL
# key 的例子: SILICONFLOW_KEY
cfg_target = {
"name" : "",
"base_url" : "",
"key" : "",
"pri_in" : 0,
"pri_out" : 0
}
cfg_target = {"name": "", "base_url": "", "key": "", "pri_in": 0, "pri_out": 0}
if config.INNER_VERSION in SpecifierSet("<=0.0.0"):
cfg_target = cfg_item
elif config.INNER_VERSION in SpecifierSet(">=0.0.1"):
stable_item = ["name","pri_in","pri_out"]
pricing_item = ["pri_in","pri_out"]
stable_item = ["name", "pri_in", "pri_out"]
pricing_item = ["pri_in", "pri_out"]
# 从配置中原始拷贝稳定字段
for i in stable_item:
# 如果 字段 属于计费项 且获取不到,那默认值是 0
@@ -241,21 +248,19 @@ class BotConfig:
try:
cfg_target[i] = cfg_item[i]
except KeyError as e:
logger.error(f"{item} 中的必要字段 {e} 不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查")
logger.error(f"{item} 中的必要字段不存在,请检查")
raise KeyError(f"{item} 中的必要字段 {e} 不存在,请检查") from e
provider = cfg_item.get("provider")
if provider == None:
if provider is None:
logger.error(f"provider 字段在模型配置 {item} 中不存在,请检查")
raise KeyError(f"provider 字段在模型配置 {item} 中不存在,请检查")
cfg_target["base_url"] = f"{provider}_BASE_URL"
cfg_target["key"] = f"{provider}_KEY"
# 如果 列表中的项目在 model_config 中,利用反射来设置对应项目
setattr(config,item,cfg_target)
setattr(config, item, cfg_target)
else:
logger.error(f"模型 {item} 在config中不存在请检查")
raise KeyError(f"模型 {item} 在config中不存在请检查")
@@ -265,19 +270,30 @@ class BotConfig:
config.MIN_TEXT_LENGTH = msg_config.get("min_text_length", config.MIN_TEXT_LENGTH)
config.MAX_CONTEXT_SIZE = msg_config.get("max_context_size", config.MAX_CONTEXT_SIZE)
config.emoji_chance = msg_config.get("emoji_chance", config.emoji_chance)
config.ban_words=msg_config.get("ban_words",config.ban_words)
config.ban_words = msg_config.get("ban_words", config.ban_words)
if config.INNER_VERSION in SpecifierSet(">=0.0.2"):
config.thinking_timeout = msg_config.get("thinking_timeout", config.thinking_timeout)
config.response_willing_amplifier = msg_config.get("response_willing_amplifier", config.response_willing_amplifier)
config.response_interested_rate_amplifier = msg_config.get("response_interested_rate_amplifier", config.response_interested_rate_amplifier)
config.response_willing_amplifier = msg_config.get(
"response_willing_amplifier", config.response_willing_amplifier
)
config.response_interested_rate_amplifier = msg_config.get(
"response_interested_rate_amplifier", config.response_interested_rate_amplifier
)
config.down_frequency_rate = msg_config.get("down_frequency_rate", config.down_frequency_rate)
if config.INNER_VERSION in SpecifierSet(">=0.0.6"):
config.ban_msgs_regex = msg_config.get("ban_msgs_regex", config.ban_msgs_regex)
def memory(parent: dict):
    """Apply the ``memory`` table of the parsed config document to ``config``.

    The two interval settings keep their current ``config`` value when the
    key is missing. ``memory_ban_words`` is only processed for config
    versions matching ``>=0.0.4`` and becomes an empty set when the key is
    absent.
    """
    section = parent["memory"]
    for attr in ("build_memory_interval", "forget_memory_interval"):
        setattr(config, attr, section.get(attr, getattr(config, attr)))

    # Option introduced with config version 0.0.4; older files are left alone.
    if config.INNER_VERSION in SpecifierSet(">=0.0.4"):
        banned = section.get("memory_ban_words", [])
        config.memory_ban_words = set(banned)
def mood(parent: dict):
mood_config = parent["mood"]
config.mood_update_interval = mood_config.get("mood_update_interval", config.mood_update_interval)
@@ -294,8 +310,12 @@ class BotConfig:
config.chinese_typo_enable = chinese_typo_config.get("enable", config.chinese_typo_enable)
config.chinese_typo_error_rate = chinese_typo_config.get("error_rate", config.chinese_typo_error_rate)
config.chinese_typo_min_freq = chinese_typo_config.get("min_freq", config.chinese_typo_min_freq)
config.chinese_typo_tone_error_rate = chinese_typo_config.get("tone_error_rate", config.chinese_typo_tone_error_rate)
config.chinese_typo_word_replace_rate = chinese_typo_config.get("word_replace_rate", config.chinese_typo_word_replace_rate)
config.chinese_typo_tone_error_rate = chinese_typo_config.get(
"tone_error_rate", config.chinese_typo_tone_error_rate
)
config.chinese_typo_word_replace_rate = chinese_typo_config.get(
"word_replace_rate", config.chinese_typo_word_replace_rate
)
def groups(parent: dict):
groups_config = parent["groups"]
@@ -307,6 +327,7 @@ class BotConfig:
others_config = parent["others"]
config.enable_advance_output = others_config.get("enable_advance_output", config.enable_advance_output)
config.enable_kuuki_read = others_config.get("enable_kuuki_read", config.enable_kuuki_read)
config.enable_debug_output = others_config.get("enable_debug_output", config.enable_debug_output)
# 版本表达式:>=1.0.0,<2.0.0
# 允许字段func: method, support: str, notice: str, necessary: bool
@@ -314,60 +335,19 @@ class BotConfig:
# 例如:"notice": "personality 将在 1.3.2 后被移除",那么在有效版本中的用户就会虽然可以
# 正常执行程序,但是会看到这条自定义提示
include_configs = {
"personality": {
"func": personality,
"support": ">=0.0.0"
},
"emoji": {
"func": emoji,
"support": ">=0.0.0"
},
"cq_code": {
"func": cq_code,
"support": ">=0.0.0"
},
"bot": {
"func": bot,
"support": ">=0.0.0"
},
"response": {
"func": response,
"support": ">=0.0.0"
},
"model": {
"func": model,
"support": ">=0.0.0"
},
"message": {
"func": message,
"support": ">=0.0.0"
},
"memory": {
"func": memory,
"support": ">=0.0.0"
},
"mood": {
"func": mood,
"support": ">=0.0.0"
},
"keywords_reaction": {
"func": keywords_reaction,
"support": ">=0.0.2",
"necessary": False
},
"chinese_typo": {
"func": chinese_typo,
"support": ">=0.0.3",
"necessary": False
},
"groups": {
"func": groups,
"support": ">=0.0.0"
},
"others": {
"func": others,
"support": ">=0.0.0"
}
"personality": {"func": personality, "support": ">=0.0.0"},
"emoji": {"func": emoji, "support": ">=0.0.0"},
"cq_code": {"func": cq_code, "support": ">=0.0.0"},
"bot": {"func": bot, "support": ">=0.0.0"},
"response": {"func": response, "support": ">=0.0.0"},
"model": {"func": model, "support": ">=0.0.0"},
"message": {"func": message, "support": ">=0.0.0"},
"memory": {"func": memory, "support": ">=0.0.0", "necessary": False},
"mood": {"func": mood, "support": ">=0.0.0"},
"keywords_reaction": {"func": keywords_reaction, "support": ">=0.0.2", "necessary": False},
"chinese_typo": {"func": chinese_typo, "support": ">=0.0.3", "necessary": False},
"groups": {"func": groups, "support": ">=0.0.0"},
"others": {"func": others, "support": ">=0.0.0"},
}
# 原地修改,将 字符串版本表达式 转换成 版本对象
@@ -379,10 +359,10 @@ class BotConfig:
with open(config_path, "rb") as f:
try:
toml_dict = tomli.load(f)
except(tomli.TOMLDecodeError) as e:
except tomli.TOMLDecodeError as e:
logger.critical(f"配置文件bot_config.toml填写有误请检查第{e.lineno}行第{e.colno}处:{e.msg}")
exit(1)
# 获取配置文件版本
config.INNER_VERSION = cls.get_config_version(toml_dict)
@@ -394,7 +374,7 @@ class BotConfig:
# 检查配置文件版本是否在支持范围内
if config.INNER_VERSION in group_specifierset:
# 如果版本在支持范围内,检查是否存在通知
if 'notice' in include_configs[key]:
if "notice" in include_configs[key]:
logger.warning(include_configs[key]["notice"])
include_configs[key]["func"](toml_dict)
@@ -406,31 +386,32 @@ class BotConfig:
f"当前程序仅支持以下版本范围: {group_specifierset}"
)
raise InvalidVersion(f"当前程序仅支持以下版本范围: {group_specifierset}")
# 如果 necessary 项目存在,而且显式声明是 False进入特殊处理
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") == False:
elif "necessary" in include_configs[key] and include_configs[key].get("necessary") is False:
# 通过 pass 处理的项虽然直接忽略也是可以的,但是为了不增加理解困难,依然需要在这里显式处理
if key == "keywords_reaction":
pass
else:
# 如果用户根本没有需要的配置项,提示缺少配置
logger.error(f"配置文件中缺少必需的字段: '{key}'")
raise KeyError(f"配置文件中缺少必需的字段: '{key}'")
logger.success(f"成功加载配置文件: {config_path}")
return config
return config
# 获取配置文件路径
bot_config_floder_path = BotConfig.get_config_dir()
print(f"正在品鉴配置文件目录: {bot_config_floder_path}")
logger.debug(f"正在品鉴配置文件目录: {bot_config_floder_path}")
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
if os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
print(f"异常的新鲜,异常的美味: {bot_config_path}")
logger.debug(f"异常的新鲜,异常的美味: {bot_config_path}")
logger.info("使用bot配置文件")
else:
# 配置文件不存在
@@ -439,8 +420,10 @@ else:
global_config = BotConfig.load_config(config_path=bot_config_path)
if not global_config.enable_advance_output:
logger.remove()
pass
# 调试输出功能
if global_config.enable_debug_output:
logger.remove()
logger.add(sys.stdout, level="DEBUG")

View File

@@ -1,23 +1,24 @@
import base64
import html
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, List, Optional, Union
import requests
# 解析各种CQ码
# 包含CQ码类
import urllib3
from loguru import logger
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .utils_image import storage_emoji, storage_image
from .utils_user import get_user_nickname
from .message_base import Seg
from .utils_user import get_user_nickname,get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver()
config = driver.config
@@ -35,65 +36,80 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False):
    """Build the urllib3 pool manager with this adapter's SSL context.

    Args:
        connections: Forwarded to ``PoolManager`` as ``num_pools``.
        maxsize: Forwarded to ``PoolManager`` as ``maxsize``.
        block: Forwarded to ``PoolManager`` as ``block``.
    """
    # The argument list used to appear twice (a compact form followed by the
    # one-argument-per-line form); only a single, well-formed call is kept.
    self.poolmanager = urllib3.poolmanager.PoolManager(
        num_pools=connections,
        maxsize=maxsize,
        block=block,
        ssl_context=self.ssl_context,
    )
@dataclass
class CQCode:
"""
CQ码数据类用于存储和处理CQ码
属性:
type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
# raw_code: str
group_id: int
user_id: int
group_name: str = ""
user_nickname: str = ""
translated_plain_text: Optional[str] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
pass
def translate(self):
    """Translate this CQ code into ``Seg`` form, stored in ``translated_segments``.

    Handling per ``self.type``:

    * ``text``    -> text segment holding ``params["text"]``.
    * ``image``   -> ``image`` segment when ``sub_type`` is ``"0"``, otherwise
      ``emoji``; a text placeholder when no image data could be obtained.
    * ``at``      -> text segment with the looked-up nickname, or a generic
      placeholder when the lookup returns nothing.
    * ``reply`` / ``forward`` -> ``seglist`` built by ``translate_reply`` /
      ``translate_forward``, or a text placeholder when they return nothing.
    * ``face``    -> text segment with the mapped face name.
    * anything else -> text segment containing the type name in brackets.
    """
    if self.type == "text":
        self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
    elif self.type == "image":
        base64_data = self.translate_image()
        if base64_data:
            if self.params.get("sub_type") == "0":
                self.translated_segments = Seg(type="image", data=base64_data)
            else:
                self.translated_segments = Seg(type="emoji", data=base64_data)
        else:
            self.translated_segments = Seg(type="text", data="[图片]")
    elif self.type == "at":
        user_nickname = get_user_nickname(self.params.get("qq", ""))
        self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
    elif self.type == "reply":
        reply_segments = self.translate_reply()
        if reply_segments:
            self.translated_segments = Seg(type="seglist", data=reply_segments)
        else:
            self.translated_segments = Seg(type="text", data="[回复某人消息]")
    elif self.type == "face":
        face_id = self.params.get("id", "")
        try:
            face_name = emojimapper.get(int(face_id), "表情")
        except (TypeError, ValueError):
            # A missing or non-numeric id used to raise out of translate();
            # use the same generic label as an unmapped id instead.
            face_name = "表情"
        self.translated_segments = Seg(type="text", data=f"[{face_name}]")
    elif self.type == "forward":
        forward_segments = self.translate_forward()
        if forward_segments:
            self.translated_segments = Seg(type="seglist", data=forward_segments)
        else:
            self.translated_segments = Seg(type="text", data="[转发消息]")
    else:
        self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self):
'''
"""
headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8',
@@ -102,18 +118,18 @@ class CQCode:
'Cache-Control': 'no-cache',
'Pragma': 'no-cache'
}
'''
"""
# 腾讯专用请求头配置
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
'Accept': 'text/html, application/xhtml xml, */*',
'Accept-Encoding': 'gbk, GB2312',
'Accept-Language': 'zh-cn',
'Content-Type': 'application/x-www-form-urlencoded',
'Cache-Control': 'no-cache'
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
"Accept-Encoding": "gbk, GB2312",
"Accept-Language": "zh-cn",
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params['url'])
if not url.startswith(('http://', 'https://')):
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
# 创建专用会话
@@ -129,247 +145,214 @@ class CQCode:
headers=headers,
timeout=15,
allow_redirects=True,
stream=True # 流式传输避免大内存问题
stream=True, # 流式传输避免大内存问题
)
# 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型
content_type = response.headers.get('Content-Type', '')
if not content_type.startswith('image/'):
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8')
image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避
logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5**retry) # 指数退避
except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}")
except Exception:
logger.exception("[未知错误]")
return None
return None
async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码"""
if 'url' not in self.params:
return '[表情包]'
base64_str = self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
def translate_image(self) -> Optional[str]:
    """Return the image as a base64 string, or ``None`` when unavailable.

    Without a ``url`` parameter nothing is fetched and ``None`` is returned;
    otherwise the result of ``get_img()`` is passed through unchanged.
    """
    return self.get_img() if "url" in self.params else None
def translate_forward(self) -> Optional[List[Seg]]:
    """Convert a forwarded-message CQ code into a flat list of ``Seg`` objects.

    ``params["content"]`` is unescaped and parsed as a Python literal that is
    expected to be a list of message dicts. For every message three segments
    are appended: ``"<nickname>: "``, the message content, and a newline.

    Returns:
        The list of segments, or ``None`` when there is no ``content``
        parameter, the content cannot be parsed, or any other error occurs
        (the error is logged).
    """
    try:
        if "content" not in self.params:
            return None

        content = self.unescape(self.params["content"])

        import ast

        try:
            messages = ast.literal_eval(content)
        except (ValueError, SyntaxError) as e:
            # literal_eval reports malformed input with either exception;
            # both should produce the parse-specific log line.
            logger.error(f"解析转发消息内容失败: {str(e)}")
            return None

        formatted_segments = []
        for msg in messages:
            sender = msg.get("sender", {})
            nickname = sender.get("card") or sender.get("nickname", "未知用户")
            raw_message = msg.get("raw_message", "")
            message_array = msg.get("message", [])

            # A nested forward is not expanded; it is shown as a placeholder.
            has_nested_forward = (
                bool(message_array)
                and isinstance(message_array, list)
                and any(part.get("type") == "forward" for part in message_array)
            )

            if has_nested_forward:
                content_seg = Seg(type="text", data="[转发消息]")
            elif raw_message:
                # Local import, as in translate_reply.
                from .message_cq import MessageRecvCQ

                user_info = UserInfo(
                    platform="qq",
                    user_id=msg.get("user_id", 0),
                    user_nickname=nickname,
                )
                group_info = GroupInfo(
                    platform="qq",
                    group_id=msg.get("group_id", 0),
                    group_name=get_groupname(msg.get("group_id", 0)),
                )
                message_obj = MessageRecvCQ(
                    message_id=msg.get("message_id", 0),
                    user_info=user_info,
                    raw_message=raw_message,
                    plain_text=raw_message,
                    group_info=group_info,
                )
                content_seg = Seg(type="seglist", data=[message_obj.message_segment])
            else:
                content_seg = Seg(type="text", data="[空消息]")

            formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
            formatted_segments.append(content_seg)
            formatted_segments.append(Seg(type="text", data="\n"))

        return formatted_segments

    except Exception as e:
        logger.error(f"处理转发消息失败: {str(e)}")
        return None
def translate_reply(self) -> Optional[List[Seg]]:
    """Convert a reply CQ code into a list of ``Seg`` objects.

    The result has three parts: an opening text segment naming whose message
    is being replied to (the bot's configured nickname when the quoted
    sender is the bot itself, otherwise the sender's nickname), a ``seglist``
    wrapping the quoted message's own segment, and a closing ``"]"``.

    Returns:
        The three segments, or ``None`` when there is no reply message or its
        sender has no user id.
    """
    from .message_cq import MessageRecvCQ

    if self.reply_message is None:
        return None

    sender = self.reply_message.sender
    if not sender.user_id:
        return None

    message_obj = MessageRecvCQ(
        user_info=UserInfo(user_id=sender.user_id, user_nickname=sender.nickname),
        message_id=self.reply_message.message_id,
        raw_message=str(self.reply_message.message),
        group_info=GroupInfo(group_id=self.reply_message.group_id),
    )

    if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
        opening = Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: ")
    else:
        opening = Seg(type="text", data=f"[回复 {sender.nickname} 的消息: ")

    return [
        opening,
        Seg(type="seglist", data=[message_obj.message_segment]),
        Seg(type="text", data="]"),
    ]
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \
.replace('&#91;', '[') \
.replace('&#93;', ']') \
.replace('&amp;', '&')
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
return (
text.replace("&#44;", ",")
.replace("&#91;", "[")
.replace("&#93;", "]")
.replace("&amp;", "&")
)
class CQCode_tool:
@staticmethod
def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
    """Build a ``CQCode`` from its dict form and translate it immediately.

    Args:
        cq_code: CQ code dict; ``type`` defaults to ``"text"`` when absent.
        msg: Message object whose ``message_info.group_info`` and
            ``message_info.user_info`` are attached to the new instance.
        reply: The message being replied to, if any.

    Returns:
        The constructed ``CQCode`` after ``translate()`` has been called.
    """
    cq_type = cq_code.get("type", "text")

    if cq_type == "text":
        # Text codes only carry their text; everything else keeps the raw
        # data dict as parameters.
        params = {"text": cq_code.get("data", {}).get("text", "")}
    else:
        params = cq_code.get("data", {})

    instance = CQCode(
        type=cq_type,
        params=params,
        group_info=msg.message_info.group_info,
        user_info=msg.message_info.user_info,
        reply_message=reply,
    )

    instance.translate()
    return instance
@staticmethod
@@ -383,5 +366,64 @@ class CQCode_tool:
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = (
abs_path.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool()

View File

@@ -1,9 +1,11 @@
import asyncio
import base64
import hashlib
import os
import random
import time
import traceback
from typing import Optional
from typing import Optional, Tuple
from loguru import logger
from nonebot import get_driver
@@ -11,34 +13,37 @@ from nonebot import get_driver
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64
from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
image_manager = ImageManager()
class EmojiManager:
_instance = None
EMOJI_DIR = "data/emoji" # 表情包存储目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
    """Attach the database handle and create the two LLM clients.

    ``vlm`` is the vision model client used for image prompts;
    ``llm_emotion_judge`` is the text model client used to pick an emotion
    for an outgoing message.
    """
    self.db = Database.get_instance()
    self._scan_task = None
    self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
    # Previously a client built from llm_normal_minor was created first and
    # then immediately overwritten; only the llm_emotion_judge client is kept.
    # Higher temperature and fewer tokens; the temperature may later be
    # adjusted according to mood.
    self.llm_emotion_judge = LLM_request(
        model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8
    )
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
os.makedirs(self.EMOJI_DIR, exist_ok=True)
def initialize(self):
"""初始化数据库连接和表情目录"""
if not self._initialized:
@@ -49,16 +54,16 @@ class EmojiManager:
self._initialized = True
# 启动时执行一次完整性检查
self.check_emoji_file_integrity()
except Exception as e:
logger.error(f"初始化表情管理器失败: {str(e)}")
except Exception:
logger.exception("初始化表情管理器失败")
def _ensure_db(self):
"""确保数据库已初始化"""
if not self._initialized:
self.initialize()
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self):
"""确保emoji集合存在并创建索引
@@ -74,9 +79,8 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
@@ -88,7 +92,7 @@ class EmojiManager:
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[str]:
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
"""根据文本内容获取相关表情包
Args:
text: 输入文本
@@ -102,9 +106,9 @@ class EmojiManager:
"""
try:
self._ensure_db()
# 获取文本的embedding
text_for_search= await self._get_kimoji_for_text(text)
text_for_search = await self._get_kimoji_for_text(text)
if not text_for_search:
logger.error("无法获取文本的情绪")
return None
@@ -112,15 +116,15 @@ class EmojiManager:
if not text_embedding:
logger.error("无法获取文本的embedding")
return None
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
return None
# 计算余弦相似度并排序
def cosine_similarity(v1, v2):
if not v1 or not v2:
@@ -131,25 +135,25 @@ class EmojiManager:
if norm_v1 == 0 or norm_v2 == 0:
return 0
return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度
emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis
]
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis:
if not top_10_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
@@ -157,57 +161,61 @@ class EmojiManager:
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success(f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'],"[ %s ]" % selected_emoji.get('description', '无描述')
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('description', '无描述')
except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}")
return None
return None
except Exception as e:
logger.error(f"获取表情包失败: {str(e)}")
return None
async def _get_emoji_description(self, image_base64: str) -> str:
"""获取表情包的标签"""
async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}")
return content
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip('[]').replace('表情包:', '')
return description
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
return None
async def _check_emoji(self, image_base64: str) -> str:
    """Ask the vision model whether an emoji satisfies the configured check prompt.

    Args:
        image_base64: The emoji image as a base64 string.

    Returns:
        The model's raw answer, or ``None`` when the request fails (the
        failure is logged).
    """
    try:
        prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
        verdict, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
        logger.debug(f"输出描述: {verdict}")
        return verdict
    except Exception as exc:
        logger.error(f"获取标签失败: {str(exc)}")
        return None
async def _get_kimoji_for_text(self, text: str):
    """Ask the emotion-judge model which feeling an emoji for ``text`` should convey.

    Args:
        text: The message the bot is about to send.

    Returns:
        The model's answer, or ``None`` when the request fails (the failure
        is logged).
    """
    try:
        prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'

        # The request used to appear twice (with and without an explicit
        # temperature); a single call with temperature=1.5 is kept.
        content, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=1.5)
        logger.info(f"输出描述: {content}")
        return content

    except Exception as e:
        logger.error(f"获取标签失败: {str(e)}")
        return None
async def scan_new_emojis(self):
"""扫描新的表情包"""
try:
@@ -215,62 +223,122 @@ class EmojiManager:
os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
files_to_process = [f for f in os.listdir(emoji_dir) if
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
# 获取图片的base64编码和哈希值
image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# 获取表情包的描述
description = await self._get_emoji_description(image_base64)
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述
description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {description}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}")
continue
logger.info(f"check通过 {check}")
embedding = await get_embedding(description)
if description is not None:
embedding = await get_embedding(description)
if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录
emoji_record = {
'filename': filename,
'path': image_path,
'embedding':embedding,
'description': description,
'embedding': embedding,
'discription': description,
'hash': image_hash,
'timestamp': int(time.time())
}
# 保存到数据库
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else:
logger.warning(f"跳过表情包: {filename}")
except Exception as e:
logger.error(f"扫描表情包失败: {str(e)}")
logger.error(traceback.format_exc())
except Exception:
logger.exception("扫描表情包失败")
async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包"""
while True:
print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
logger.info("开始扫描新表情包...")
await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self):
"""检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录
@@ -281,7 +349,7 @@ class EmojiManager:
all_emojis = list(self.db.db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
for emoji in all_emojis:
try:
if 'path' not in emoji:
@@ -289,27 +357,27 @@ class EmojiManager:
self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
# 检查文件是否存在
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.success(f"成功删除数据库记录: {emoji['_id']}")
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
else:
logger.error(f"删除数据库记录失败: {emoji['_id']}")
except Exception as item_error:
logger.error(f"处理表情包记录时出错: {str(item_error)}")
continue
# 验证清理结果
remaining_count = self.db.db.emoji.count_documents({})
if removed_count > 0:
@@ -317,7 +385,7 @@ class EmojiManager:
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
else:
logger.info(f"已检查 {total_count} 个表情包记录")
except Exception as e:
logger.error(f"检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc())
@@ -328,6 +396,8 @@ class EmojiManager:
await asyncio.sleep(interval_MINS * 60)
# 创建全局单例
emoji_manager = EmojiManager()
emoji_manager = EmojiManager()

View File

@@ -3,11 +3,12 @@ import time
from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from loguru import logger
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import Message
from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@@ -18,58 +19,89 @@ config = driver.config
class ResponseGenerator:
def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
self.model_r1 = LLM_request(
model=global_config.llm_reasoning,
temperature=0.7,
max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型
rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1'
self.current_model_type = "r1"
current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
self.current_model_type = 'v3'
elif (
rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3
else:
self.current_model_type = 'r1_distill'
self.current_model_type = "r1_distill"
current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++")
model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(
message, current_model
)
raw_content = model_response
# print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}")
if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response)
if model_response:
return model_response, raw_content
return None, raw_content
return model_response ,raw_content
return None,raw_content
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
async def _generate_response_with_model(
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}"
if message.user_cardname:
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
sender_name = (
message.chat_stream.user_info.user_nickname
or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass
# 构建prompt
prompt, prompt_check = await prompt_builder._build_prompt(
message_txt=message.processed_plain_text,
sender_name=sender_name,
relationship_value=relationship_value,
group_id=message.group_id
stream_id=message.chat_stream.stream_id,
)
# 读空气模块 简化逻辑,先停用
@@ -92,10 +124,10 @@ class ResponseGenerator:
# 生成回复
try:
content, reasoning_content = await model.generate_response(prompt)
except Exception as e:
print(f"生成回复时出错: {e}")
except Exception:
logger.exception("生成回复时出错")
return None
# 保存到数据库
self._save_to_db(
message=message,
@@ -107,54 +139,73 @@ class ResponseGenerator:
reasoning_content=reasoning_content,
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
)
return content
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
content: str, reasoning_content: str,):
def _save_to_db(
self,
message: MessageRecv,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({
'time': time.time(),
'group_id': message.group_id,
'user': sender_name,
'message': message.processed_plain_text,
'model': self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
'reasoning': reasoning_content,
'response': content,
'prompt': prompt,
'prompt_check': prompt_check
})
self.db.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
"user": sender_name,
"message": message.processed_plain_text,
"model": self.current_model_type,
# 'reasoning_check': reasoning_content_check,
# 'response_check': content_check,
"reasoning": reasoning_content,
"response": content,
"prompt": prompt,
"prompt_check": prompt_check,
}
)
async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签"""
try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好,不要输出其他内容:
内容:{content}
输出:
'''
"""
content, _ = await self.model_v25.generate_response(prompt)
content=content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
content = content.strip()
if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content]
else:
return ["neutral"]
except Exception as e:
print(f"获取情感标签时出错: {e}")
return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
"""处理响应内容,返回处理后的内容和情感标签"""
if not content:
return None, []
processed_response = process_llm_response(content)
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response
@@ -172,7 +223,7 @@ class InitiativeMessageGenerate:
prompt_builder._build_initiative_prompt_select(message.group_id)
)
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
print(f"[DEBUG] {content_select} {reasoning}")
logger.debug(f"{content_select} {reasoning}")
topics_list = [dot[0] for dot in dots_for_select]
if content_select:
if content_select in topics_list:
@@ -185,12 +236,12 @@ class InitiativeMessageGenerate:
select_dot[1], prompt_template
)
content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
print(f"[DEBUG] {content_check} {reasoning_check}")
logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower():
return None
prompt = prompt_builder._build_initiative_prompt(
select_dot, prompt_template, memory
)
content, reasoning = self.model_r1.generate_response_async(prompt)
print(f"[DEBUG] {content} {reasoning}")
logger.debug(f"[DEBUG] {content} {reasoning}")
return content

View File

@@ -1,14 +1,16 @@
import time
import html
import re
import json
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional
from typing import Dict, List, Optional
import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
Message = ForwardRef('Message') # 添加这行
from .utils_image import image_manager
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -16,216 +18,383 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
    """Message bound to a chat stream, carrying its rendered text forms.

    Adds to ``MessageBase`` the owning ``chat_stream``, an optional ``reply``
    target and two text renderings that both default to the empty string.
    """
    chat_stream: ChatStream=None
    reply: Optional['Message'] = None
    detailed_plain_text: str = ""
    processed_plain_text: str = ""

    def __init__(
        self,
        message_id: str,
        time: int,
        chat_stream: ChatStream,
        user_info: UserInfo,
        message_segment: Optional[Seg] = None,
        reply: Optional['MessageRecv'] = None,
        detailed_plain_text: str = "",
        processed_plain_text: str = "",
    ):
        """Build the base message info from the chat stream and store the extras.

        Platform and group info are taken from ``chat_stream``; the sender is
        the explicitly supplied ``user_info``. ``raw_message`` is always None.
        """
        super().__init__(
            message_info=BaseMessageInfo(
                platform=chat_stream.platform,
                message_id=message_id,
                time=time,
                group_info=chat_stream.group_info,
                user_info=user_info,
            ),
            message_segment=message_segment,
            raw_message=None,
        )

        # Stream this message belongs to and the message it replies to, if any.
        self.chat_stream = chat_stream
        self.reply = reply

        # Text renderings; callers may pre-populate them.
        self.detailed_plain_text = detailed_plain_text
        self.processed_plain_text = processed_plain_text
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
)
# 构建详细文本
if self.time is None:
self.time = int(time.time())
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
)
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
class MessageRecv(Message):
"""接收消息类用于处理从MessageCQ序列化的消息"""
async def parse_message_segments(self, message: str) -> List[CQCode]:
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典,包含:
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
start = 0
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
message_segment = message_dict.get('message_segment', {})
if message_segment.get('data','') == '[json]':
# 提取json消息中的展示信息
pattern = r'\[CQ:json,data=(?P<json_data>.+?)\]'
match = re.search(pattern, message_dict.get('raw_message',''))
raw_json = html.unescape(match.group('json_data'))
try:
json_message = json.loads(raw_json)
except json.JSONDecodeError:
json_message = {}
message_segment['data'] = json_message.get('prompt','')
self.message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
self.raw_message = message_dict.get('raw_message')
self.message_id = message_id
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
# 思考状态相关属性
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
self.is_emoji=True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageProcessBase(Message):
"""消息处理基类,用于处理中和发送中的消息"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=int(time.time()),
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.interupt=False
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
@dataclass
class Message_Sending(Message):
"""发送中的消息类"""
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
reply_message_id: int = None # 存储 回复的 源消息ID
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
    """Message in the "thinking" state: created before any content exists."""

    def __init__(
        self,
        message_id: str,
        chat_stream: ChatStream,
        bot_user_info: UserInfo,
        reply: Optional['MessageRecv'] = None
    ):
        """Initialise the base state without a message segment.

        Args:
            message_id: Identifier assigned to this message.
            chat_stream: Chat stream the message belongs to.
            bot_user_info: User info recorded as the sender.
            reply: Message being replied to, if any.
        """
        # No content yet, so the segment is left at the parent's default (None).
        super().__init__(message_id, chat_stream, bot_user_info, reply=reply)

        # Flag specific to the thinking state; starts cleared.
        self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False,
is_emoji: bool = False
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
def set_reply(self, reply: Optional['MessageRecv']) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False
) -> 'MessageSending':
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji
)
def to_dict(self):
ret= super().to_dict()
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id
self.user_id = user_id
def __init__(self, chat_stream: ChatStream, message_id: str):
self.chat_stream = chat_stream
self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注
self.messages: List[MessageSending] = []
self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None:
"""添加消息到集合只接受Message_Sending类型"""
if not isinstance(message, Message_Sending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
# 按时间排序
self.messages.sort(key=lambda x: x.time)
self.messages.sort(key=lambda x: x.message_info.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].time < target_time:
if self.messages[mid].message_info.time < target_time:
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool:
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
    """One fragment of a message, or a container of nested fragments.

    Attributes:
        type: Fragment kind, e.g. 'text', 'image' or 'seglist'.
        data: Fragment payload:
            - for 'text', the text itself;
            - for 'image', a base64 string;
            - for 'seglist', a list of ``Seg`` instances.
    """
    type: str
    data: Union[str, List['Seg']]

    @classmethod
    def from_dict(cls, data: Dict) -> 'Seg':
        """Build a ``Seg`` from its dict form, recursing into 'seglist' payloads.

        Args:
            data: Mapping with optional 'type' and 'data' keys.

        Returns:
            The reconstructed segment. A 'seglist' whose payload is missing or
            None becomes an empty list instead of raising.
        """
        seg_type = data.get('type')
        payload = data.get('data')
        if seg_type == 'seglist':
            # Nested entries are dicts themselves and need the same conversion.
            payload = [cls.from_dict(item) for item in payload or []]
        return cls(type=seg_type, data=payload)

    def to_dict(self) -> Dict:
        """Return the dict form, serialising nested segments of a 'seglist'."""
        if self.type == 'seglist':
            payload = [seg.to_dict() for seg in self.data]
        else:
            payload = self.data
        return {'type': self.type, 'data': payload}
@dataclass
class GroupInfo:
    """Identity of a chat group; every field is optional."""
    platform: Optional[str] = None
    group_id: Optional[int] = None
    group_name: Optional[str] = None  # display name of the group

    def to_dict(self) -> Dict:
        """Return the fields as a dict, leaving out those that are None."""
        populated = {}
        for key, value in asdict(self).items():
            if value is None:
                continue
            populated[key] = value
        return populated

    @classmethod
    def from_dict(cls, data: Dict) -> 'GroupInfo':
        """Create a ``GroupInfo`` from a dict.

        Args:
            data: Mapping that may hold 'platform', 'group_id' and
                'group_name'; absent keys become None, other keys are ignored.

        Returns:
            GroupInfo: the new instance.
        """
        known_keys = ('platform', 'group_id', 'group_name')
        return cls(**{key: data.get(key) for key in known_keys})
@dataclass
class UserInfo:
    """Identity of a user; every field is optional."""
    platform: Optional[str] = None
    user_id: Optional[int] = None
    user_nickname: Optional[str] = None  # user's nickname
    user_cardname: Optional[str] = None  # user's per-group card name

    def to_dict(self) -> Dict:
        """Return the fields as a dict, leaving out those that are None."""
        populated = {}
        for key, value in asdict(self).items():
            if value is None:
                continue
            populated[key] = value
        return populated

    @classmethod
    def from_dict(cls, data: Dict) -> 'UserInfo':
        """Create a ``UserInfo`` from a dict.

        Args:
            data: Mapping that may hold 'platform', 'user_id', 'user_nickname'
                and 'user_cardname'; absent keys become None, other keys are
                ignored.

        Returns:
            UserInfo: the new instance.
        """
        known_keys = ('platform', 'user_id', 'user_nickname', 'user_cardname')
        return cls(**{key: data.get(key) for key in known_keys})
@dataclass
class BaseMessageInfo:
    """Metadata of a message: platform, id, time, group and user."""
    platform: Optional[str] = None
    message_id: Union[str,int,None] = None
    time: Optional[int] = None
    group_info: Optional[GroupInfo] = None
    user_info: Optional[UserInfo] = None

    def to_dict(self) -> Dict:
        """Return the fields as a dict, leaving out those that are None.

        Nested ``GroupInfo`` / ``UserInfo`` values are serialised with their
        own ``to_dict`` so their None fields are dropped as well.
        """
        result = {}
        # Read the live attributes instead of asdict(): asdict() recursively
        # turns nested dataclasses into plain dicts, so the isinstance check
        # below would never match and nested None fields would leak through.
        for name in self.__dataclass_fields__:
            value = getattr(self, name)
            if value is None:
                continue
            if isinstance(value, (GroupInfo, UserInfo)):
                result[name] = value.to_dict()
            else:
                result[name] = value
        return result

    @classmethod
    def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
        """Create a ``BaseMessageInfo`` from a dict.

        Args:
            data: Mapping that may hold 'platform', 'message_id', 'time',
                'group_info' and 'user_info'. The two nested entries are dicts
                themselves; a missing or None entry yields an all-None
                ``GroupInfo`` / ``UserInfo``.

        Returns:
            BaseMessageInfo: the new instance.
        """
        # Use the nested classes' own from_dict so unknown keys are ignored
        # rather than raising TypeError from the dataclass constructor.
        group_info = GroupInfo.from_dict(data.get('group_info') or {})
        user_info = UserInfo.from_dict(data.get('user_info') or {})
        return cls(
            platform=data.get('platform'),
            message_id=data.get('message_id'),
            time=data.get('time'),
            group_info=group_info,
            user_info=user_info
        )
@dataclass
class MessageBase:
    """Base message: metadata, content segment and optional raw text."""
    message_info: BaseMessageInfo
    message_segment: Seg
    raw_message: Optional[str] = None  # original message, CQ codes unparsed

    def to_dict(self) -> Dict:
        """Return the dict form of the message.

        Returns:
            Dict: contains
                - message_info: dict form of the metadata;
                - message_segment: dict form of the segment, only when one is set;
                - raw_message: only when it is not None.
        """
        result = {'message_info': self.message_info.to_dict()}
        # Subclasses may be constructed with message_segment=None; skip the
        # key then instead of failing on None.to_dict().
        if self.message_segment is not None:
            result['message_segment'] = self.message_segment.to_dict()
        if self.raw_message is not None:
            result['raw_message'] = self.raw_message
        return result

    @classmethod
    def from_dict(cls, data: Dict) -> 'MessageBase':
        """Create a ``MessageBase`` from a dict.

        Args:
            data: Mapping that may hold 'message_info', 'message_segment' and
                'raw_message'.

        Returns:
            MessageBase: the new instance. Nested entries are rebuilt with
            their own ``from_dict`` so group/user info and nested segment
            lists become objects again; a missing segment stays None.
        """
        message_info = BaseMessageInfo.from_dict(data.get('message_info') or {})
        segment_data = data.get('message_segment')
        message_segment = Seg.from_dict(segment_data) if segment_data else None
        return cls(
            message_info=message_info,
            message_segment=message_segment,
            raw_message=data.get('raw_message')
        )

View File

@@ -0,0 +1,169 @@
import time
from dataclasses import dataclass
from typing import Dict, Optional
import urllib3
from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
    """Base class for QQ messages, built on ``MessageBase``.

    The constructor only assembles the message metadata; the message segment
    and raw message start as None and are expected to be set by subclasses.
    """

    def __init__(
        self,
        message_id: int,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        platform: str = "qq"
    ):
        """Create the message metadata, timestamped with the current time.

        Args:
            message_id: Message identifier.
            user_info: User associated with the message.
            group_info: Group the message belongs to, if any.
            platform: Platform tag, "qq" by default.
        """
        super().__init__(
            message_info=BaseMessageInfo(
                platform=platform,
                message_id=message_id,
                time=int(time.time()),  # whole seconds at construction
                group_info=group_info,
                user_info=user_info,
            ),
            message_segment=None,
            raw_message=None,
        )
@dataclass
class MessageRecvCQ(MessageCQ):
    """Incoming QQ message: parses ``raw_message`` into a Seg object."""

    def __init__(
        self,
        message_id: int,
        user_info: UserInfo,
        raw_message: str,
        group_info: Optional[GroupInfo] = None,
        platform: str = "qq",
        reply_message: Optional[Dict] = None,
    ):
        """Create the message and parse its raw text.

        Args:
            message_id: Message ID.
            user_info: Sender information.
            raw_message: Raw text, possibly containing CQ codes.
            group_info: Group the message came from; None for a private chat.
            platform: Platform identifier.
            reply_message: Replied-to message, forwarded to the CQ-code parser.
        """
        super().__init__(message_id, user_info, group_info, platform)

        # group_info is None for private chats, so only resolve a missing
        # group name when there actually is a group.
        if group_info is not None and group_info.group_name is None:
            group_info.group_name = get_groupname(group_info.group_id)

        # Parse the message segments
        self.message_segment = self._parse_message(raw_message, reply_message)
        self.raw_message = raw_message

    def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
        """Split ``message`` into plain text and ``[CQ:...]`` parts and convert them to a Seg.

        Returns:
            Seg: The single translated segment when exactly one was produced,
            otherwise a ``seglist`` Seg wrapping all of them (possibly empty).
        """
        cq_code_dict_list = []
        segments = []

        start = 0
        while True:
            cq_start = message.find('[CQ:', start)
            if cq_start == -1:
                # No more CQ codes: the remainder (if any) is plain text.
                if start < len(message):
                    text = message[start:].strip()
                    if text:
                        cq_code_dict_list.append(parse_cq_code(text))
                break

            # Plain text between the previous position and this CQ code.
            if cq_start > start:
                text = message[start:cq_start].strip()
                if text:
                    cq_code_dict_list.append(parse_cq_code(text))

            cq_end = message.find(']', cq_start)
            if cq_end == -1:
                # Unterminated CQ code: hand the rest to the parser as-is.
                text = message[cq_start:].strip()
                if text:
                    cq_code_dict_list.append(parse_cq_code(text))
                break

            cq_code = message[cq_start:cq_end + 1]
            cq_code_dict_list.append(parse_cq_code(cq_code))
            start = cq_end + 1

        # Convert the parsed CQ codes into Seg objects
        for code_item in cq_code_dict_list:
            message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
            if message_obj.translated_segments:
                segments.append(message_obj.translated_segments)

        # A single segment is returned directly
        if len(segments) == 1:
            return segments[0]

        # Otherwise wrap everything in a seglist Seg
        return Seg(type='seglist', data=segments)

    def to_dict(self) -> Dict:
        """Convert to a dictionary (delegates to the parent implementation)."""
        base_dict = super().to_dict()
        return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
    """Outgoing QQ message: converts a Seg object into a CQ-code ``raw_message``."""

    def __init__(
        self,
        data: Dict
    ):
        """Build the message from a dictionary.

        Args:
            data: Mapping with ``message_info`` and ``message_segment``
                sub-dictionaries.
        """
        message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
        message_segment = Seg.from_dict(data.get('message_segment', {}))

        # A group_info without a group_id cannot address a group. Treat it as
        # "no group" so that a placeholder object is not mistaken for a group
        # chat by code that tests `if message_info.group_info:`.
        group_info = message_info.group_info
        if not group_info or getattr(group_info, 'group_id', None) is None:
            group_info = None

        super().__init__(
            message_info.message_id,
            message_info.user_info,
            group_info,
            message_info.platform
        )

        self.message_segment = message_segment
        self.raw_message = self._generate_raw_message()

    def _generate_raw_message(self, ) -> str:
        """Render ``self.message_segment`` as one CQ-code string."""
        segments = []

        # A seglist is rendered element by element; anything else is a single segment
        if self.message_segment.type == 'seglist':
            for seg in self.message_segment.data:
                segments.append(self._seg_to_cq_code(seg))
        else:
            segments.append(self._seg_to_cq_code(self.message_segment))

        return ''.join(segments)

    def _seg_to_cq_code(self, seg: Seg) -> str:
        """Render a single Seg as a CQ-code string.

        Text is returned as-is; unknown segment types are rendered as
        ``[<data>]``.
        """
        if seg.type == 'text':
            return str(seg.data)
        elif seg.type == 'image':
            return cq_code_tool.create_image_cq_base64(seg.data)
        elif seg.type == 'emoji':
            return cq_code_tool.create_emoji_cq_base64(seg.data)
        elif seg.type == 'at':
            return f"[CQ:at,qq={seg.data}]"
        elif seg.type == 'reply':
            return cq_code_tool.create_reply_cq(int(seg.data))
        else:
            return f"[{seg.data}]"

View File

@@ -2,224 +2,212 @@ import asyncio
import time
from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet
from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config
class Message_Sender:
    """Sends finished messages through the current bot instance."""

    def __init__(self):
        self.message_interval = (0.5, 1)  # range of the interval between messages (seconds)
        self.last_send_time = 0
        self._current_bot = None

    def set_bot(self, bot: Bot):
        """Set the bot instance used for sending."""
        self._current_bot = bot

    async def send_message(
        self,
        message: MessageSending,
    ) -> None:
        """Send ``message`` to its group, or to its user when it has no group.

        Objects that are not MessageSending are ignored. Send failures are
        logged, not raised, so one bad message does not stop the caller.
        """
        if not isinstance(message, MessageSending):
            return

        if self._current_bot is None:
            # Without this check the failure would only surface as an
            # AttributeError on None inside the generic handlers below.
            logger.error("Bot未设置请先调用set_bot方法设置bot实例")
            return

        message_json = message.to_dict()
        message_send = MessageSendCQ(
            data=message_json
        )

        # Decide group vs. private on the same object the target id is read
        # from. The rebuilt `message_send` may carry a group_info even when
        # the original message has none, which used to route private
        # messages into the group branch where reading group_id failed.
        group_info = message.message_info.group_info
        if group_info:
            try:
                await self._current_bot.send_group_msg(
                    group_id=group_info.group_id,
                    message=message_send.raw_message,
                    auto_escape=False
                )
                logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
            except Exception as e:
                logger.error(f"[调试] 发生错误 {e}")
                logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
        else:
            try:
                await self._current_bot.send_private_msg(
                    user_id=message.message_info.user_info.user_id,
                    message=message_send.raw_message,
                    auto_escape=False
                )
                logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
            except Exception as e:
                logger.error(f"发生错误 {e}")
                logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
class MessageContainer:
    """Holds the pending thinking/sending messages of a single chat stream."""

    def __init__(self, chat_id: str, max_size: int = 100):
        self.chat_id = chat_id
        self.max_size = max_size
        self.messages = []
        self.last_send_time = 0
        self.thinking_timeout = 20  # thinking timeout (seconds)

    def get_timeout_messages(self) -> List[MessageSending]:
        """Return the MessageSending objects older than ``self.thinking_timeout``.

        The result is ordered by ``thinking_start_time``, oldest first.
        """
        now = time.time()
        overdue = [
            item for item in self.messages
            if isinstance(item, MessageSending)
            and now - item.thinking_start_time > self.thinking_timeout
        ]
        overdue.sort(key=lambda item: item.thinking_start_time)
        return overdue

    def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
        """Return the message with the smallest ``thinking_start_time``, or None when empty."""
        oldest = None
        oldest_time = float('inf')
        for candidate in self.messages:
            started = candidate.thinking_start_time
            # Strict comparison: on a tie the message queued first wins.
            if started < oldest_time:
                oldest_time = started
                oldest = candidate
        return oldest

    def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
        """Queue a message; a MessageSet is unpacked into its individual messages."""
        if isinstance(message, MessageSet):
            self.messages.extend(message.messages)
            return
        self.messages.append(message)

    def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
        """Remove ``message``; return True if it was queued, False otherwise."""
        try:
            if message not in self.messages:
                return False
            self.messages.remove(message)
            return True
        except Exception:
            logger.exception("移除消息时发生错误")
            return False

    def has_messages(self) -> bool:
        """Return whether any message is waiting."""
        return len(self.messages) > 0

    def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
        """Return a shallow copy of the queued messages."""
        return self.messages[:]
class MessageManager:
    """Manages the message containers of all chat streams."""

    def __init__(self):
        self.containers: Dict[str, MessageContainer] = {}  # chat_id -> MessageContainer
        self.storage = MessageStorage()
        self._running = True

    def get_container(self, chat_id: str) -> MessageContainer:
        """Return the container of a chat stream, creating it on first use."""
        if chat_id not in self.containers:
            self.containers[chat_id] = MessageContainer(chat_id)
        return self.containers[chat_id]

    def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
        """Queue ``message`` in the container of its chat stream.

        Raises:
            ValueError: If the message has no chat stream.
        """
        chat_stream = message.chat_stream
        if not chat_stream:
            raise ValueError("无法找到对应的聊天流")
        container = self.get_container(chat_stream.stream_id)
        container.add_message(message)

    async def process_chat_messages(self, chat_id: str):
        """Process the pending messages of one chat stream.

        The earliest message is handled first: a thinking message is waited
        on (and dropped once it exceeds ``global_config.thinking_timeout``),
        a sending message is sent, processed, stored and removed. After that
        every timed-out sending message is flushed the same way.
        """
        container = self.get_container(chat_id)
        if container.has_messages():
            message_earliest = container.get_earliest_message()

            if isinstance(message_earliest, MessageThinking):
                message_earliest.update_thinking_time()
                thinking_time = message_earliest.thinking_time
                print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)

                # Check for timeout
                if thinking_time > global_config.thinking_timeout:
                    logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
                    container.remove_message(message_earliest)
            else:
                # A head message that waited more than 30s is sent as an explicit reply.
                if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
                    await message_sender.send_message(message_earliest.set_reply())
                else:
                    await message_sender.send_message(message_earliest)
                await message_earliest.process()

                print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")

                await self.storage.store_message(message_earliest, message_earliest.chat_stream, None)

                container.remove_message(message_earliest)

            message_timeout = container.get_timeout_messages()
            if message_timeout:
                logger.warning(f"发现{len(message_timeout)}条超时消息")
                for msg in message_timeout:
                    if msg == message_earliest:
                        continue  # already handled above

                    try:
                        if msg.is_head and msg.update_thinking_time() > 30:
                            await message_sender.send_message(msg.set_reply())
                        else:
                            await message_sender.send_message(msg)

                        await msg.process()
                        await self.storage.store_message(msg, msg.chat_stream, None)

                        # Remove defensively: the message may already be gone
                        if not container.remove_message(msg):
                            logger.warning("尝试删除不存在的消息")
                    except Exception:
                        logger.exception("处理超时消息时发生错误")
                        continue

    async def start_processor(self):
        """Run the processing loop: once per second, process every chat stream."""
        while self._running:
            await asyncio.sleep(1)
            # Snapshot the keys: containers may be added while tasks run.
            chat_ids = list(self.containers)
            # return_exceptions=True keeps a failure in one chat stream from
            # propagating out of gather() and terminating this loop for all
            # other streams.
            results = await asyncio.gather(
                *(self.process_chat_messages(chat_id) for chat_id in chat_ids),
                return_exceptions=True,
            )
            for chat_id, result in zip(chat_ids, results):
                if isinstance(result, Exception):
                    logger.error(f"处理聊天流 {chat_id} 的消息时发生错误: {result!r}")
# 创建全局消息管理器实例
message_manager = MessageManager()
# 创建全局发送器实例

View File

@@ -1,6 +1,7 @@
import random
import time
from typing import Optional
from loguru import logger
from ...common.database import Database
from ..memory_system.memory import hippocampus, memory_graph
@@ -8,6 +9,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import chat_manager
class PromptBuilder:
@@ -22,7 +24,7 @@ class PromptBuilder:
message_txt: str,
sender_name: str = "某人",
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt
Args:
@@ -33,57 +35,62 @@ class PromptBuilder:
Returns:
str: 构建好的prompt
"""
#先禁用关系
"""
# 先禁用关系
if 0 > 30:
relation_prompt = "关系特别特别好,你很喜欢喜欢他"
relation_prompt_2 = "热情发言或者回复"
elif 0 <-20:
elif 0 < -20:
relation_prompt = "关系很差,你很讨厌他"
relation_prompt_2 = "骂他"
else:
relation_prompt = "关系一般"
relation_prompt_2 = "发言或者回复"
#开始构建prompt
#心情
# 开始构建prompt
# 心情
mood_manager = MoodManager.get_instance()
mood_prompt = mood_manager.get_prompt()
#日程构建
# 日程构建
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
#知识构建
# 知识构建
start_time = time.time()
prompt_info = ''
promt_info_prompt = ''
prompt_info = await self.get_prompt_info(message_txt,threshold=0.5)
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[
知识],之后可能会用到-'''
end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}")
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文
chat_in_group=True
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 使用新的记忆获取方法
memory_prompt = ''
start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt,
@@ -91,67 +98,64 @@ class PromptBuilder:
similarity_threshold=0.4,
max_memory_num=5
)
if relevant_memories:
# 格式化记忆内容
memory_items = []
for memory in relevant_memories:
memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}")
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
# 打印调试信息
print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:")
logger.debug("[记忆检索]找到以下相关记忆:")
for memory in relevant_memories:
print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
end_time = time.time()
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}")
#激活prompt构建
logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建
activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot:
# is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认'
# else:
# is_bot_prompt = ''
if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
else:
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 关键词检测与反应
keywords_reaction_prompt = ''
for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
print(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
keywords_reaction_prompt += rule.get("reaction", "") + ''
#人格选择
personality=global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3
prompt_personality = ''
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}'
personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt}
请注意把握群里的聊天内容,不要刻意突出自身学科背景,不要回复的太有条理,可以有个性。'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[1]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
prompt_personality += f'''{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复,请表现你自己的见解,不要一昧迎合,尽量简短一些。{keywords_reaction_prompt}
请你表达自己的见解和观点。可以有个性。'''
#中文高手(新加的好玩功能)
# 中文高手(新加的好玩功能)
prompt_ger = ''
if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句'
@@ -159,23 +163,23 @@ class PromptBuilder:
prompt_ger += '你喜欢用反问句'
if random.random() < 0.01:
prompt_ger += '你喜欢用文言文'
#额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
#合并prompt
# 额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
# 合并prompt
prompt = ""
prompt += f"{prompt_info}\n"
prompt += f"{prompt_date}\n"
prompt += f"{chat_talking_prompt}\n"
prompt += f"{chat_talking_prompt}\n"
prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n"
'''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt += f"{extra_info}\n"
'''读空气prompt处理'''
activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < probability_1: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < probability_1 + probability_2: # 第二种人格
@@ -183,34 +187,36 @@ class PromptBuilder:
else: # 第三种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response
def _build_initiative_prompt_select(self,group_id):
prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt, prompt_check_if_response
def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task()
bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题
all_nodes=memory_graph.dots
all_nodes=filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes)
nodes_for_select=random.sample(all_nodes,5)
topics=[info[0] for info in nodes_for_select]
infos=[info[1] for info in nodes_for_select]
all_nodes = memory_graph.dots
all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes)
nodes_for_select = random.sample(all_nodes, 5)
topics = [info[0] for info in nodes_for_select]
infos = [info[1] for info in nodes_for_select]
#激活prompt构建
# 激活prompt构建
activate_prompt = ''
activate_prompt = "以上是群里正在进行的聊天。"
personality=global_config.PROMPT_PERSONALITY
personality = global_config.PROMPT_PERSONALITY
prompt_personality = ''
personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格
@@ -219,32 +225,31 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}'''
else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}'''
topics_str=','.join(f"\"{topics}\"")
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular=f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select,nodes_for_select,prompt_regular
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
memory=random.sample(selected_node['memory_items'],3)
memory='\n'.join(memory)
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check,memory
def _build_initiative_prompt(self,selected_node,prompt_regular,memory):
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
topics_str = ','.join(f"\"{topics}\"")
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular = f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt_check(self, selected_node, prompt_regular):
memory = random.sample(selected_node['memory_items'], 3)
memory = '\n'.join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory
def _build_initiative_prompt(self, selected_node, prompt_regular, memory):
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
return prompt_for_initiative
async def get_prompt_info(self, message: str, threshold: float):
    """Return knowledge-base text related to ``message``.

    The message is embedded with ``get_embedding`` and the embedding is
    looked up via ``self.get_info_from_db`` using ``threshold``.
    """
    logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
    query_embedding = await get_embedding(message)
    return '' + self.get_info_from_db(query_embedding, threshold=threshold)
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
@@ -305,14 +310,15 @@ class PromptBuilder:
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:
return ''
# 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results)
prompt_builder = PromptBuilder()
prompt_builder = PromptBuilder()

View File

@@ -1,115 +1,177 @@
import asyncio
from typing import Optional
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo
from .chat_stream import ChatStream
class Impression:
    """Plain holder for an impression of a user; every field defaults to None."""
    # The fields default to None, so they are annotated as Optional.
    traits: Optional[str] = None
    called: Optional[str] = None
    know_time: Optional[float] = None
    relationship_value: Optional[float] = None
class Relationship:
    """Relationship data for one user on one platform."""
    user_id: Optional[int] = None
    platform: Optional[str] = None
    gender: Optional[str] = None
    age: Optional[int] = None
    nickname: Optional[str] = None
    relationship_value: Optional[float] = None
    saved = False

    def __init__(self, chat: ChatStream = None, data: dict = None):
        """Create a relationship from a chat stream and/or a stored record.

        Args:
            chat: Chat stream; when given, ``user_id``, ``platform`` and
                ``nickname`` are taken from it.
            data: Stored record; supplies ``relationship_value``, ``age`` and
                ``gender``, and the identity fields when ``chat`` is absent.
        """
        # Normalise once so that a call with neither chat nor data yields the
        # defaults instead of failing on `None.get`.
        data = data or {}
        if chat:
            self.user_id = chat.user_info.user_id
            self.platform = chat.platform
            self.nickname = chat.user_info.user_nickname
        else:
            self.user_id = data.get('user_id', 0)
            self.platform = data.get('platform', '')
            self.nickname = data.get('nickname', '')
        self.relationship_value = data.get('relationship_value', 0)
        self.age = data.get('age', 0)
        self.gender = data.get('gender', '')
class RelationshipManager:
def __init__(self):
    # In-memory relationships, keyed by (user_id, platform)
    self.relationships: dict[tuple[int, str], Relationship] = dict()
async def update_relationship(self,
                              chat_stream: ChatStream,
                              data: dict = None,
                              **kwargs) -> Optional[Relationship]:
    """Update the relationship of the chat stream's user, creating it if needed.

    Args:
        chat_stream: Chat stream whose ``user_info`` identifies the user.
        data: Optional dictionary of attribute values.
        **kwargs: Further attribute values; they override ``data``.

    Returns:
        Relationship: The updated or newly created relationship.

    Raises:
        ValueError: If ``chat_stream`` carries no ``user_info``.
    """
    # Without user_info there is nothing to derive the key from.
    user_info = chat_stream.user_info
    if user_info is None:
        raise ValueError("必须提供user_id或user_info")

    # Relationships are keyed by (user_id, platform)
    key = (user_info.user_id, user_info.platform or 'qq')

    # Merge both value sources so that kwargs are neither silently ignored on
    # update nor passed to Relationship(), which accepts only `chat` and `data`.
    updates = dict(data) if isinstance(data, dict) else {}
    updates.update(kwargs)

    relationship = self.relationships.get(key)
    if relationship:
        # Already known: overwrite the attributes that were supplied
        for k, value in updates.items():
            if hasattr(relationship, k) and value is not None:
                setattr(relationship, k, value)
    else:
        # Not known yet: create it from the chat stream plus supplied values
        relationship = Relationship(chat=chat_stream, data=updates or None)
        self.relationships[key] = relationship

    # Persist to the database
    await self.storage_relationship(relationship)
    relationship.saved = True

    return relationship
async def update_relationship_value(self,
                                    chat_stream: ChatStream,
                                    **kwargs) -> Optional[Relationship]:
    """Add a delta to the relationship value of the chat stream's user.

    Args:
        chat_stream: Chat stream whose ``user_info`` identifies the user.
        **kwargs: ``relationship_value`` is the amount to add; other keys are
            only used when the relationship has to be created.

    Returns:
        Relationship: The updated (or newly created) relationship.

    Raises:
        ValueError: If ``chat_stream`` carries no ``user_info``.
    """
    user_info = chat_stream.user_info
    if user_info is None:
        raise ValueError("必须提供user_id或user_info")

    # Relationships are keyed by (user_id, platform)
    key = (user_info.user_id, user_info.platform or 'qq')

    relationship = self.relationships.get(key)
    if not relationship:
        # First contact: create the relationship. The values go through
        # `data` because update_relationship does not accept arbitrary
        # keyword arguments when it has to construct a new Relationship.
        return await self.update_relationship(chat_stream=chat_stream, data=dict(kwargs) or None)

    if 'relationship_value' in kwargs:
        relationship.relationship_value += kwargs['relationship_value']
    await self.storage_relationship(relationship)
    relationship.saved = True
    return relationship
def get_relationship(self,
                     chat_stream: ChatStream) -> Optional[Relationship]:
    """Look up the cached relationship for the user behind ``chat_stream``.

    Args:
        chat_stream: Chat stream whose ``user_info`` supplies ``user_id`` and
            ``platform``. A falsy platform falls back to ``'qq'``.

    Returns:
        The cached relationship object, or ``0`` when none is cached.

    Raises:
        ValueError: If the chat stream carries no user info or no user id.
    """
    user_info = chat_stream.user_info
    # Check for None *before* dereferencing; the previous code read
    # `chat_stream.user_info.platform` first and then fell into a branch
    # that used unbound locals.
    if user_info is None or user_info.user_id is None:
        raise ValueError("必须提供user_id或user_info")

    key = (user_info.user_id, user_info.platform or 'qq')
    # NOTE(review): the not-found sentinel is 0 although the annotation says
    # Optional[Relationship]; kept as-is because callers outside this view
    # may rely on it — confirm before switching to None.
    return self.relationships.get(key, 0)
async def load_relationship(self, data: dict) -> Relationship:
    """Build a relationship from a database record and cache it.

    Args:
        data: Stored relationship document. A missing ``platform`` key is
            filled in with ``'qq'`` (the dict is modified in place).

    Returns:
        The relationship object, marked as saved and cached under
        ``(user_id, platform)``.
    """
    # Records without a platform field default to 'qq'.
    data.setdefault('platform', 'qq')

    loaded = Relationship(data=data)
    loaded.saved = True
    self.relationships[(loaded.user_id, loaded.platform)] = loaded
    return loaded
async def load_all_relationships(self):
    """Load every document of the ``relationships`` collection into the cache."""
    collection = Database.get_instance().db.relationships
    for record in collection.find({}):
        await self.load_relationship(record)
async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
@@ -117,39 +179,37 @@ class RelationshipManager:
all_relationships = db.db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
user_id = data['user_id']
relationship = await self.load_relationship(data)
self.relationships[user_id] = relationship
print(f"\033[1;32m[关系管理]\033[0m 已加载 {len(self.relationships)} 条关系记录")
await self.load_relationship(data)
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
while True:
print("\033[1;32m[关系管理]\033[0m 正在自动保存关系")
logger.debug("正在自动保存关系")
await asyncio.sleep(300) # 等待300秒(5分钟)
await self._save_all_relationships()
async def _save_all_relationships(self):
    """Persist every cached relationship that is not yet marked as saved."""
    for pending in self.relationships.values():
        if pending.saved:
            continue
        # The flag is set before storing, exactly as before, so the stored
        # snapshot sees saved=True.
        pending.saved = True
        await self.storage_relationship(pending)
async def storage_relationship(self,relationship: Relationship):
"""
将关系记录存储到数据库中
"""
async def storage_relationship(self, relationship: Relationship):
"""将关系记录存储到数据库中"""
user_id = relationship.user_id
platform = relationship.platform
nickname = relationship.nickname
relationship_value = relationship.relationship_value
gender = relationship.gender
age = relationship.age
saved = relationship.saved
db = Database.get_instance()
db.db.relationships.update_one(
{'user_id': user_id},
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,
'nickname': nickname,
'relationship_value': relationship_value,
'gender': gender,
@@ -159,14 +219,38 @@ class RelationshipManager:
upsert=True
)
def get_name(self,
             user_id: int = None,
             platform: str = None,
             user_info: UserInfo = None) -> str:
    """Return the nickname to use for a user.

    Args:
        user_id: User id; ignored when ``user_info`` is given.
        platform: Platform name; ignored when ``user_info`` is given.
            A falsy value falls back to ``'qq'``.
        user_info: User info object; takes precedence over the two
            arguments above.

    Returns:
        The cached relationship's nickname if one exists; otherwise the
        nickname or card name from ``user_info``; otherwise ``"某人"``.

    Raises:
        ValueError: If no user id can be determined.
    """
    has_info = user_info is not None
    if has_info:
        user_id = user_info.user_id
        platform = user_info.platform
    platform = platform or 'qq'

    if user_id is None:
        raise ValueError("必须提供user_id或user_info")

    # Cache keys use an integer user id.
    cached = self.relationships.get((int(user_id), platform))
    if (int(user_id), platform) in self.relationships:
        return cached.nickname
    if has_info:
        return user_info.user_nickname or user_info.user_cardname or "某人"
    return "某人"
relationship_manager = RelationshipManager()
relationship_manager = RelationshipManager()

View File

@@ -1,49 +1,30 @@
from typing import Optional
from typing import Optional, Union
from ...common.database import Database
from .message import Message
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
class MessageStorage:
    """Writes chat messages into the ``messages`` collection."""

    def __init__(self):
        self.db = Database.get_instance()

    async def store_message(self, message: Union[MessageSending, MessageRecv], chat_stream: ChatStream, topic: Optional[str] = None) -> None:
        """Store one message document in the database.

        Args:
            message: Message to store; its ``message_info`` supplies the id,
                time and user info.
            chat_stream: Chat stream the message belongs to; its
                ``stream_id`` is stored as ``chat_id``.
            topic: Optional topic stored alongside the message.

        Any exception raised while building or inserting the document is
        logged and swallowed.
        """
        try:
            info = message.message_info
            document = dict(
                message_id=info.message_id,
                time=info.time,
                chat_id=chat_stream.stream_id,
                chat_info=chat_stream.to_dict(),
                user_info=info.user_info.to_dict(),
                processed_plain_text=message.processed_plain_text,
                topic=topic,
                detailed_plain_text=message.detailed_plain_text,
            )
            self.db.db.messages.insert_one(document)
        except Exception:
            logger.exception("存储消息失败")
# 如果需要其他存储相关的函数,可以在这里添加
# 如果需要其他存储相关的函数,可以在这里添加

View File

@@ -4,9 +4,11 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request
from .config import global_config
from loguru import logger
driver = get_driver()
config = driver.config
config = driver.config
class TopicIdentifier:
def __init__(self):
@@ -23,19 +25,20 @@ class TopicIdentifier:
# 使用 LLM_request 类进行请求
topic, _ = await self.llm_topic_judge.generate_response(prompt)
if not topic:
print("\033[1;31m[错误]\033[0m LLM API 返回为空")
logger.error("LLM API 返回为空")
return None
# 直接在这里处理主题解析
if not topic or topic == "无主题":
return None
# 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}")
logger.info(f"主题: {topic_list}")
return topic_list if topic_list else None
topic_identifier = TopicIdentifier()
topic_identifier = TopicIdentifier()

View File

@@ -7,65 +7,44 @@ from typing import Dict, List
import jieba
import numpy as np
from nonebot import get_driver
from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message import Message
from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
    """Join messages into one formatted transcript string.

    Args:
        messages: Message objects to render.

    Returns:
        One ``[MM-DD HH:MM:SS] name: content`` line per message, each
        terminated by a newline (empty string for an empty list).
    """
    lines = []
    for msg in messages:
        stamp = time.strftime("%m-%d %H:%M:%S", time.localtime(msg.time))
        # Fall back to the user id when no nickname is set, and to the raw
        # plain text when no processed text is available.
        speaker = msg.user_nickname or f"用户{msg.user_id}"
        body = msg.processed_plain_text or msg.plain_text
        lines.append(f"[{stamp}] {speaker}: {body}\n")
    return "".join(lines)
def db_message_to_str(message_dict: Dict) -> str:
    """Render one stored message document as a ``[time] name: content`` line.

    Args:
        message_dict: Message document. Sender fields (``user_id``,
            ``user_nickname``, ``user_cardname``) are read from the top level
            when ``user_id`` is present there, otherwise from a nested
            ``user_info`` dict.

    Returns:
        The formatted line, terminated by a newline.

    Raises:
        KeyError: If the document has no ``time`` field.
    """
    logger.debug(f"message_dict: {message_dict}")
    time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))

    # Documents written by the chat-stream based storage keep the sender under
    # "user_info" instead of at the top level.
    # NOTE(review): assumes the nested dict uses the same key names as the
    # legacy top-level fields — confirm against UserInfo.to_dict().
    nested = message_dict.get("user_info")
    if not isinstance(nested, dict):
        nested = {}
    sender = message_dict if "user_id" in message_dict else nested

    nickname = sender.get("user_nickname", "")
    if "user_id" in sender:
        name = "[(%s)%s]%s" % (sender["user_id"], nickname, sender.get("user_cardname", ""))
    else:
        # No user id anywhere: the old fallback re-indexed 'user_id' and
        # raised KeyError from inside a bare except; degrade gracefully instead.
        name = nickname or "用户"

    content = message_dict.get("processed_plain_text", "")
    result = f"[{time_str}] {name}: {content}\n"
    logger.debug(f"result: {result}")
    return result
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
    """Tell whether the message text contains the bot's nickname or an alias.

    Args:
        message: Received message; its ``processed_plain_text`` is searched.

    Returns:
        True if ``global_config.BOT_NICKNAME`` or any entry of
        ``global_config.BOT_ALIAS_NAMES`` occurs in the text.
    """
    primary_name = global_config.BOT_NICKNAME
    aliases = global_config.BOT_ALIAS_NAMES
    if primary_name in message.processed_plain_text:
        return True
    return any(alias in message.processed_plain_text for alias in aliases)
@@ -98,40 +77,45 @@ def calculate_information_content(text):
def get_cloest_chat_from_db(db, length: int, timestamp: str):
    """Fetch the messages that follow the record closest to ``timestamp``.

    The newest message whose ``time`` is not later than ``timestamp`` is used
    as an anchor; up to ``length`` later messages with the same ``chat_id``
    are then returned in ascending time order.

    Args:
        db: Database wrapper exposing the ``messages`` collection as
            ``db.db.messages``.
        length: Maximum number of messages to return.
        timestamp: Upper bound compared against the stored ``time`` field.

    Returns:
        list: Dicts with ``time``, ``chat_id`` and ``detailed_plain_text``
        (empty string when missing); an empty list if no anchor exists.
    """
    anchor = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
    if not anchor:
        return []

    # Only messages after the anchor and from the same chat are considered.
    query = {
        "time": {"$gt": anchor['time']},
        "chat_id": anchor['chat_id'],
    }
    following = list(db.db.messages.find(query).sort('time', 1).limit(length))

    return [
        {
            'time': row["time"],
            'chat_id': row["chat_id"],
            'detailed_plain_text': row.get("detailed_plain_text", ""),
        }
        for row in following
    ]
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
@@ -145,38 +129,31 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
{"chat_id": chat_id},
).sort("time", -1).limit(limit))
if not recent_messages:
return []
# 转换为 Message对象列表
from .message import Message
message_objects = []
for msg_data in recent_messages:
try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"],
chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id
detailed_plain_text=msg_data.get("detailed_plain_text", "")
)
await msg.initialize()
message_objects.append(msg)
except KeyError:
print("[WARNING] 数据库中存在无效的消息")
logger.warning("数据库中存在无效的消息")
continue
# 按时间正序排列
@@ -184,13 +161,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find(
{"group_id": group_id},
{"chat_id": chat_stream_id},
{
"time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段
"user_nickname": 1, # 返回用户昵称字段
"chat_id":1,
"chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段
}
@@ -292,11 +270,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentence = sentence.replace('', ' ').replace(',', ' ')
sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}")
logger.info(f"处理后的句子: {sentences_done}")
return sentences_done
def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯
@@ -324,11 +301,10 @@ def random_remove_punctuation(text: str) -> str:
return result
def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content)
if len(text) > 200:
print(f"回复过长 ({len(text)} 字符),返回默认回复")
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说']
# 处理长消息
typo_generator = ChineseTypoGenerator(
@@ -348,9 +324,9 @@ def process_llm_response(text: str) -> List[str]:
else:
sentences.append(sentence)
# 检查分割后的消息数量是否过多超过3条
if len(sentences) > 5:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences
@@ -372,15 +348,15 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_
mood_arousal = mood_manager.current_mood.arousal
# 映射到0.5到2倍的速度系数
typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
chinese_time *= 1/typing_speed_multiplier
english_time *= 1/typing_speed_multiplier
chinese_time *= 1 / typing_speed_multiplier
english_time *= 1 / typing_speed_multiplier
# 计算中文字符数
chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff')
# 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1:
return chinese_time * 3 + 0.3 # 加上回车时间
# 正常计算所有字符的输入时间
total_time = 0.0
for char in input_string:

View File

@@ -1,296 +1,353 @@
import base64
import io
import os
import time
import zlib # 用于 CRC32
import aiohttp
import hashlib
from typing import Optional, Union
from loguru import logger
from nonebot import get_driver
from PIL import Image
from ...common.database import Database
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args:
base64_data: base64编码的图片数据
max_size: 最大文件大小KB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保图片目录存在
images_dir = "data/images"
os.makedirs(images_dir, exist_ok=True)
# 连接数据库
db = Database(
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片
collection = db.db['images']
existing_image = collection.find_one({'hash': hash_value})
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
# 如果已经小于目标大小,直接使用原图
if current_size <= max_size:
compressed_data = image_data
else:
# 压缩逻辑
# 先缩放到50%
new_width = int(img.width * 0.5)
new_height = int(img.height * 0.5)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 如果缩放后的最大边长仍然大于400继续缩放
max_dimension = 400
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# 使用固定质量参数压缩
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(images_dir, filename)
# 保存文件
with open(image_path, "wb") as f:
f.write(compressed_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
try:
# 准备数据库记录
image_record = {
'filename': filename,
'path': image_path,
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
}
# 保存记录
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes:
"""
存储表情包到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
if not global_config.EMOJI_SAVE:
return image_data
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(emoji_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
emoji_path = os.path.join(emoji_dir, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
image_dir = "data/image"
os.makedirs(image_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(image_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
Args:
image_hash: 图片哈希值
description_type: 描述类型 ('emoji''image')
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
)
async def save_image(self,
                     image_data: Union[str, bytes],
                     url: str = None,
                     description: str = None,
                     is_base64: bool = False) -> Optional[str]:
    """Store an image on disk and register it in the ``images`` collection.

    Args:
        image_data: Base64 string (when ``is_base64``) or raw bytes.
        url: Source URL recorded with the image.
        description: Description recorded with the image.
        is_base64: Whether ``image_data`` is a base64 string.

    Returns:
        Path of the stored file — the already recorded path when an image
        with the same MD5 hash exists — or None on a type mismatch or any
        failure.
    """
    try:
        # The payload type must match the declared encoding.
        wanted_type = str if is_base64 else bytes
        if not isinstance(image_data, wanted_type):
            return None
        raw = base64.b64decode(image_data) if is_base64 else image_data

        digest = hashlib.md5(raw).hexdigest()

        # Deduplicate on content hash.
        duplicate = self.db.db.images.find_one({'hash': digest})
        if duplicate:
            return duplicate['path']

        saved_at = int(time.time())
        target = os.path.join(self.IMAGE_DIR, f"{saved_at}_{digest[:8]}.jpg")
        with open(target, "wb") as out:
            out.write(raw)

        self.db.db.images.insert_one({
            'hash': digest,
            'path': target,
            'url': url,
            'description': description,
            'timestamp': saved_at,
        })
        return target
    except Exception as e:
        logger.error(f"保存图像失败: {str(e)}")
        return None
async def get_image_by_url(self, url: str) -> Optional[str]:
    """Return the local path of the image at ``url``, downloading if needed.

    Args:
        url: Image URL.

    Returns:
        The recorded path when the URL is already known; otherwise the result
        of ``save_image`` for the downloaded bytes. None when the HTTP status
        is not 200 or any error occurs.
    """
    try:
        known = self.db.db.images.find_one({'url': url})
        if known:
            return known['path']

        # Not recorded yet: fetch the image and hand it to save_image.
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                if resp.status != 200:
                    return None
                payload = await resp.read()
                return await self.save_image(payload, url=url)
    except Exception as e:
        logger.error(f"获取图像失败: {str(e)}")
        return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
    """Return the image at ``url`` as a base64 string.

    Args:
        url: Image URL, resolved to a local file via ``get_image_by_url``.

    Returns:
        Base64 text of the file content, or None when no local path is
        available or any error occurs.
    """
    try:
        local_path = await self.get_image_by_url(url)
        if not local_path:
            return None
        with open(local_path, 'rb') as fh:
            encoded = base64.b64encode(fh.read())
        return encoded.decode('utf-8')
    except Exception as e:
        logger.error(f"获取base64失败: {str(e)}")
        return None
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
def check_url_exists(self, url: str) -> bool:
    """Tell whether an image with this URL is recorded.

    Args:
        url: Image URL.

    Returns:
        bool: True if the ``images`` collection has a document for the URL.
    """
    match = self.db.db.images.find_one({'url': url})
    return match is not None
return base64.b64encode(compressed_data).decode('utf-8')
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
    """Tell whether an image with the same content hash is recorded.

    Args:
        image_data: Base64 string (when ``is_base64``) or raw bytes.
        is_base64: Whether ``image_data`` is a base64 string.

    Returns:
        bool: True if a document with the image's MD5 hash exists; False on
        a type mismatch or any error.
    """
    try:
        # The payload type must match the declared encoding.
        wanted_type = str if is_base64 else bytes
        if not isinstance(image_data, wanted_type):
            return False
        raw = base64.b64decode(image_data) if is_base64 else image_data

        digest = hashlib.md5(raw).hexdigest()
        match = self.db.db.images.find_one({'hash': digest})
        return match is not None
    except Exception as e:
        logger.error(f"检查哈希失败: {str(e)}")
        return False
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
async def get_emoji_description(self, image_base64: str) -> str:
    """Describe a sticker image, using cached descriptions when available.

    Args:
        image_base64: Base64-encoded image data.

    Returns:
        ``"[表情包:<description>]"`` on success, or ``"[表情包]"`` when the
        model returns no description or any error occurs.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        image_hash = hashlib.md5(image_bytes).hexdigest()

        # Reuse a description stored earlier for the same content hash.
        cached_description = self._get_description_from_db(image_hash, 'emoji')
        if cached_description:
            logger.info(f"缓存表情包描述: {cached_description}")
            return f"[表情包:{cached_description}]"

        prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
        description, _ = await self._llm.generate_response_for_image(prompt, image_base64)

        # Do not cache or return a missing description (previously this
        # produced "[表情包:None]" and stored None in the database).
        if description is None:
            logger.warning("AI未能生成表情包描述")
            return "[表情包]"

        if global_config.EMOJI_SAVE:
            timestamp = int(time.time())
            filename = f"{timestamp}_{image_hash[:8]}.jpg"
            file_path = os.path.join(self.IMAGE_DIR, 'emoji', filename)

            try:
                # Create the 'emoji' subdirectory on demand so the write
                # below does not fail when it does not exist yet.
                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                with open(file_path, "wb") as f:
                    f.write(image_bytes)

                image_doc = {
                    'hash': image_hash,
                    'path': file_path,
                    'type': 'emoji',
                    'description': description,
                    'timestamp': timestamp
                }
                self.db.db.images.update_one(
                    {'hash': image_hash},
                    {'$set': image_doc},
                    upsert=True
                )
                logger.success(f"保存表情包: {file_path}")
            except Exception as e:
                # Saving the file is best-effort; the description is still cached.
                logger.error(f"保存表情包文件失败: {str(e)}")

        self._save_description_to_db(image_hash, description, 'emoji')

        return f"[表情包:{description}]"
    except Exception as e:
        logger.error(f"获取表情包描述失败: {str(e)}")
        return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
    """Describe an ordinary image, using cached descriptions when available.

    Args:
        image_base64: Base64-encoded image data.

    Returns:
        ``"[图片:<description>]"`` on success, or ``"[图片]"`` when the model
        returns no description or any error occurs.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        image_hash = hashlib.md5(image_bytes).hexdigest()

        # Reuse a description stored earlier for the same content hash.
        cached_description = self._get_description_from_db(image_hash, 'image')
        if cached_description:
            return f"[图片:{cached_description}]"

        prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
        description, _ = await self._llm.generate_response_for_image(prompt, image_base64)

        if description is None:
            logger.warning("AI未能生成图片描述")
            return "[图片]"

        # NOTE(review): image saving is gated by the EMOJI_SAVE flag, the
        # same flag the sticker path uses — confirm this is intended.
        if global_config.EMOJI_SAVE:
            timestamp = int(time.time())
            filename = f"{timestamp}_{image_hash[:8]}.jpg"
            file_path = os.path.join(self.IMAGE_DIR, 'image', filename)

            try:
                # Create the 'image' subdirectory on demand so the write
                # below does not fail when it does not exist yet.
                os.makedirs(os.path.dirname(file_path), exist_ok=True)
                with open(file_path, "wb") as f:
                    f.write(image_bytes)

                image_doc = {
                    'hash': image_hash,
                    'path': file_path,
                    'type': 'image',
                    'description': description,
                    'timestamp': timestamp
                }
                self.db.db.images.update_one(
                    {'hash': image_hash},
                    {'$set': image_doc},
                    upsert=True
                )
                logger.success(f"保存图片: {file_path}")
            except Exception as e:
                # Saving the file is best-effort; the description is still cached.
                logger.error(f"保存图片文件失败: {str(e)}")

        self._save_description_to_db(image_hash, description, 'image')

        return f"[图片:{description}]"
    except Exception as e:
        logger.error(f"获取图片描述失败: {str(e)}")
        return "[图片]"
# 创建全局单例
image_manager = ImageManager()
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码

View File

@@ -1,10 +1,15 @@
import asyncio
from typing import Dict
from .config import global_config
from .chat_stream import ChatStream
class WillingManager:
def __init__(self):
self.group_reply_willing = {} # 存储每个的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None
self._started = False
@@ -12,20 +17,38 @@ class WillingManager:
"""定期衰减回复意愿"""
while True:
await asyncio.sleep(5)
for group_id in self.group_reply_willing:
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self, chat_stream: ChatStream) -> float:
    """Return the reply willingness stored for a chat stream.

    Args:
        chat_stream: Chat stream to look up by ``stream_id``.

    Returns:
        The stored willingness, or 0 when the stream is falsy or unknown.
    """
    if not chat_stream:
        return 0
    return self.chat_reply_willing.get(chat_stream.stream_id, 0)
def set_willing(self, chat_id: str, willing: float):
    """Store the reply willingness for a chat stream id.

    Args:
        chat_id: Stream id used as the key.
        willing: Willingness value to store.
    """
    self.chat_reply_willing.update({chat_id: willing})
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
"""改变指定群组的回复意愿并返回回复概率"""
current_willing = self.group_reply_willing.get(group_id, 0)
async def change_reply_willing_received(self,
chat_stream:ChatStream,
topic: str = None,
is_mentioned_bot: bool = False,
config = None,
is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流
stream = chat_stream
chat_id = stream.stream_id
current_willing = self.chat_reply_willing.get(chat_id, 0)
# print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0:
@@ -49,31 +72,33 @@ class WillingManager:
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
reply_probability = max((current_willing - 0.45) * 2, 0)
if group_id not in config.talk_allowed_groups:
current_willing = 0
reply_probability = 0
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
# 检查群组权限(如果是群聊)
if chat_stream.group_info:
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate
reply_probability = min(reply_probability, 1)
if reply_probability < 0:
reply_probability = 0
self.group_reply_willing[group_id] = min(current_willing, 3.0)
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
return reply_probability
def change_reply_willing_sent(self, chat_stream: ChatStream):
    """Lower a chat stream's reply willingness by 2, not going below 0.

    Args:
        chat_stream: Chat stream to adjust; a falsy value is ignored.
    """
    if not chat_stream:
        return
    chat_id = chat_stream.stream_id
    lowered = self.chat_reply_willing.get(chat_id, 0) - 2
    self.chat_reply_willing[chat_id] = max(0, lowered)
def change_reply_willing_after_sent(self, chat_stream: ChatStream):
    """Raise a chat stream's reply willingness by 0.2, up to at most 1.

    Values that are already 1 or higher are left unchanged.

    Args:
        chat_stream: Chat stream to adjust; a falsy value is ignored.
    """
    if not chat_stream:
        return
    chat_id = chat_stream.stream_id
    current = self.chat_reply_willing.get(chat_id, 0)
    if current < 1:
        self.chat_reply_willing[chat_id] = min(1, current + 0.2)
async def ensure_started(self):
"""确保衰减任务已启动"""
@@ -83,4 +108,4 @@ class WillingManager:
self._started = True
# 创建全局实例
willing_manager = WillingManager()
willing_manager = WillingManager()