diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 000000000..0d1e50c5a --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,8 @@ +name: Ruff +on: [ push, pull_request ] +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/ruff-action@v3 \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..8a04e2d84 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.9.10 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/bot.py b/bot.py index 48517fe24..a3a844a15 100644 --- a/bot.py +++ b/bot.py @@ -17,19 +17,6 @@ env_mask = {key: os.getenv(key) for key in os.environ} uvicorn_server = None -# 配置日志 -log_path = os.path.join(os.getcwd(), "logs") -if not os.path.exists(log_path): - os.makedirs(log_path) - -# 添加文件日志,启用rotation和retention -logger.add( - os.path.join(log_path, "maimbot_{time:YYYY-MM-DD}.log"), - rotation="00:00", # 每天0点创建新文件 - retention="30 days", # 保留30天的日志 - level="INFO", - encoding="utf-8" -) def easter_egg(): # 彩蛋 @@ -76,7 +63,7 @@ def init_env(): # 首先加载基础环境变量.env if os.path.exists(".env"): - load_dotenv(".env",override=True) + load_dotenv(".env", override=True) logger.success("成功加载基础环境变量配置") @@ -90,10 +77,7 @@ def load_env(): logger.success("加载开发环境变量配置") load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 - fn_map = { - "prod": prod, - "dev": dev - } + fn_map = {"prod": prod, "dev": dev} env = os.getenv("ENVIRONMENT") logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}") @@ -109,28 +93,45 @@ def load_env(): logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") -def load_logger(): - logger.remove() # 移除默认配置 - if os.getenv("ENVIRONMENT") == "dev": - logger.add( - sys.stderr, - format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <7} | {name:.<8}:{function:.<8}:{line: >4} - {message}", - colorize=True, - level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别,默认为DEBUG - ) - else: - logger.add( - sys.stderr, - format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <7} | {name:.<8}:{function:.<8}:{line: >4} - {message}", - colorize=True, - level=os.getenv("LOG_LEVEL", "INFO"), # 根据环境设置日志级别,默认为INFO - filter=lambda record: "nonebot" not in record["name"] - ) +def load_logger(): + logger.remove() + + # 配置日志基础路径 + log_path = os.path.join(os.getcwd(), "logs") + if not os.path.exists(log_path): + os.makedirs(log_path) + + current_env = os.getenv("ENVIRONMENT", "dev") + + # 公共配置参数 + log_level = os.getenv("LOG_LEVEL", "INFO" if current_env == "prod" else "DEBUG") + log_filter = lambda record: ( + ("nonebot" not in record["name"] or record["level"].no >= logger.level("ERROR").no) + if current_env == "prod" + else True + ) + log_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSS} " + "| {level: <7} " + "| {name:.<8}:{function:.<8}:{line: >4} " + "- {message}" + ) + + # 日志文件储存至/logs + logger.add( + os.path.join(log_path, "maimbot_{time:YYYY-MM-DD}.log"), + rotation="00:00", + retention="30 days", + format=log_format, + colorize=False, + level=log_level, + filter=log_filter, + encoding="utf-8", + ) + + # 终端输出 + logger.add(sys.stderr, format=log_format, colorize=True, level=log_level, filter=log_filter) def scan_provider(env_config: dict): @@ -160,10 +161,7 @@ def scan_provider(env_config: dict): # 检查每个 provider 是否同时存在 url 和 key for provider_name, config in provider.items(): if config["url"] is None or config["key"] is None: - logger.error( - f"provider 内容:{config}\n" - f"env_config 内容:{env_config}" - ) + logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") @@ -192,7 +190,7 @@ async def uvicorn_main(): reload=os.getenv("ENVIRONMENT") == "dev", timeout_graceful_shutdown=5, log_config=None, - access_log=False + access_log=False, ) server = uvicorn.Server(config) uvicorn_server = server @@ -202,7 +200,7 @@ async def uvicorn_main(): def raw_main(): # 利用 TZ 环境变量设定程序工作的时区 # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 - if platform.system().lower() != 'windows': + if platform.system().lower() != "windows": time.tzset() easter_egg() diff --git a/docs/avatars/SengokuCola.jpg b/docs/avatars/SengokuCola.jpg new file mode 100644 index 000000000..deebf5ed5 Binary files /dev/null and b/docs/avatars/SengokuCola.jpg differ diff --git a/docs/avatars/default.png b/docs/avatars/default.png new file mode 100644 index 000000000..5b561dac4 Binary files /dev/null and b/docs/avatars/default.png differ diff --git a/docs/avatars/run.bat b/docs/avatars/run.bat new file mode 100644 index 000000000..6b9ca9f2b --- /dev/null +++ b/docs/avatars/run.bat @@ -0,0 +1 @@ +gource gource.log --user-image-dir docs/avatars/ --default-user-image docs/avatars/default.png \ No newline at end of file diff --git a/src/common/database.py b/src/common/database.py index ca73dc468..cd149e526 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -22,9 +22,7 @@ def __create_database_instance(): if username and password: # 如果有用户名和密码,使用认证连接 - return MongoClient( - host, port, username=username, password=password, authSource=auth_source - ) + return MongoClient(host, port, username=username, password=password, authSource=auth_source) # 否则使用无认证连接 return MongoClient(host, port) diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index e79f8f91f..c577ba3ae 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import Dict, List from loguru import logger from typing import Optional -from ..common.database import db + import customtkinter as ctk from dotenv import load_dotenv @@ -16,6 +16,8 @@ from dotenv import load_dotenv current_dir = os.path.dirname(os.path.abspath(__file__)) # 获取项目根目录 root_dir = os.path.abspath(os.path.join(current_dir, '..', '..')) +sys.path.insert(0, root_dir) +from src.common.database import db # 加载环境变量 if os.path.exists(os.path.join(root_dir, '.env.dev')): diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index d7a7bd7e4..26b3d36da 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -3,8 +3,9 @@ import time import os from loguru import logger -from nonebot import get_driver, on_message, require -from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent +from nonebot import get_driver, on_message, on_notice, require +from nonebot.rule import to_me +from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment, MessageEvent, NoticeEvent from nonebot.typing import T_State from ..moods.moods import MoodManager # 导入情绪管理器 @@ -39,6 +40,8 @@ logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") chat_bot = ChatBot() # 注册消息处理器 msg_in = on_message(priority=5) +# 注册和bot相关的通知处理器 +notice_matcher = on_notice(priority=1) # 创建定时任务 scheduler = require("nonebot_plugin_apscheduler").scheduler @@ -95,19 +98,24 @@ async def _(bot: Bot, event: MessageEvent, state: T_State): await chat_bot.handle_message(event, bot) +@notice_matcher.handle() +async def _(bot: Bot, event: NoticeEvent, state: T_State): + logger.debug(f"收到通知:{event}") + await chat_bot.handle_notice(event, bot) + + # 添加build_memory定时任务 @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") async def build_memory_task(): """每build_memory_interval秒执行一次记忆构建""" - logger.debug( - "[记忆构建]" - "------------------------------------开始构建记忆--------------------------------------") + logger.debug("[记忆构建]------------------------------------开始构建记忆--------------------------------------") start_time = time.time() await hippocampus.operation_build_memory(chat_size=20) end_time = time.time() logger.success( f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} " - "秒-------------------------------------------") + "秒-------------------------------------------" + ) @scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 1db38477c..b90b3d0f3 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -7,6 +7,8 @@ from nonebot.adapters.onebot.v11 import ( GroupMessageEvent, MessageEvent, PrivateMessageEvent, + NoticeEvent, + PokeNotifyEvent, ) from ..memory_system.memory import hippocampus @@ -25,6 +27,7 @@ from .relationship_manager import relationship_manager from .storage import MessageStorage from .utils import calculate_typing_time, is_mentioned_bot_in_message from .utils_image import image_path_to_base64 +from .utils_user import get_user_nickname, get_user_cardname, get_groupname from .willing_manager import willing_manager # 导入意愿管理器 from .message_base import UserInfo, GroupInfo, Seg @@ -46,6 +49,69 @@ class ChatBot: if not self._started: self._started = True + async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None: + """处理收到的通知""" + # 戳一戳通知 + if isinstance(event, PokeNotifyEvent): + # 用户屏蔽,不区分私聊/群聊 + if event.user_id in global_config.ban_user_id: + return + reply_poke_probability = 1 # 回复戳一戳的概率 + + if random() < reply_poke_probability: + user_info = UserInfo( + user_id=event.user_id, + user_nickname=get_user_nickname(event.user_id) or None, + user_cardname=get_user_cardname(event.user_id) or None, + platform="qq", + ) + group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq") + message_cq = MessageRecvCQ( + message_id=None, + user_info=user_info, + raw_message=str("[戳了戳]你"), + group_info=group_info, + reply_message=None, + platform="qq", + ) + message_json = message_cq.to_dict() + + # 进入maimbot + message = MessageRecv(message_json) + groupinfo = message.message_info.group_info + userinfo = message.message_info.user_info + messageinfo = message.message_info + + chat = await chat_manager.get_or_create_stream( + platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo + ) + message.update_chat_stream(chat) + await message.process() + + bot_user_info = UserInfo( + user_id=global_config.BOT_QQ, + user_nickname=global_config.BOT_NICKNAME, + platform=messageinfo.platform, + ) + + response, raw_content = await self.gpt.generate_response(message) + + if response: + for msg in response: + message_segment = Seg(type="text", data=msg) + + bot_message = MessageSending( + message_id=None, + chat_stream=chat, + bot_user_info=bot_user_info, + sender_info=userinfo, + message_segment=message_segment, + reply=None, + is_head=False, + is_emoji=False, + ) + message_manager.add_message(bot_message) + async def handle_message(self, event: MessageEvent, bot: Bot) -> None: """处理收到的消息""" @@ -54,7 +120,10 @@ class ChatBot: # 用户屏蔽,不区分私聊/群聊 if event.user_id in global_config.ban_user_id: return - + + if event.reply and hasattr(event.reply, 'sender') and hasattr(event.reply.sender, 'user_id') and event.reply.sender.user_id in global_config.ban_user_id: + logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息") + return # 处理私聊消息 if isinstance(event, PrivateMessageEvent): if not global_config.enable_friend_chat: # 私聊过滤 @@ -126,7 +195,7 @@ class ChatBot: for word in global_config.ban_words: if word in message.processed_plain_text: logger.info( - f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}" + f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}" ) logger.info(f"[过滤词识别]消息中含有{word},filtered") return @@ -135,7 +204,7 @@ class ChatBot: for pattern in global_config.ban_msgs_regex: if re.search(pattern, message.raw_message): logger.info( - f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.raw_message}" + f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{message.raw_message}" ) logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered") return @@ -143,7 +212,7 @@ class ChatBot: current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time)) # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) - + topic = "" interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100 logger.debug(f"对{message.processed_plain_text}的激活度:{interested_rate}") @@ -163,7 +232,7 @@ class ChatBot: current_willing = willing_manager.get_willing(chat_stream=chat) logger.info( - f"[{current_time}][{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{chat.user_info.user_nickname}:" + f"[{current_time}][{chat.group_info.group_name if chat.group_info else '私聊'}]{chat.user_info.user_nickname}:" f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]" ) diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index bc40cff80..049419f1c 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -86,9 +86,12 @@ class CQCode: else: self.translated_segments = Seg(type="text", data="[图片]") elif self.type == "at": - user_nickname = get_user_nickname(self.params.get("qq", "")) - self.translated_segments = Seg( - type="text", data=f"[@{user_nickname or '某人'}]" + if self.params.get("qq") == "all": + self.translated_segments = Seg(type="text", data="@[全体成员]") + else: + user_nickname = get_user_nickname(self.params.get("qq", "")) + self.translated_segments = Seg( + type="text", data=f"[@{user_nickname or '某人'}]" ) elif self.type == "reply": reply_segments = self.translate_reply() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 822eda009..76437f8f2 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -36,9 +36,9 @@ class EmojiManager: def __init__(self): self._scan_task = None self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) - self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60, - temperature=0.8) # 更高的温度,更少的token(后续可以根据情绪来调整温度) - + self.llm_emotion_judge = LLM_request( + model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8 + ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) def _ensure_emoji_dir(self): """确保表情存储目录存在""" @@ -65,42 +65,39 @@ class EmojiManager: def _ensure_emoji_collection(self): """确保emoji集合存在并创建索引 - + 这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。 - + 索引的作用是加快数据库查询速度: - embedding字段的2dsphere索引: 用于加速向量相似度搜索,帮助快速找到相似的表情包 - tags字段的普通索引: 加快按标签搜索表情包的速度 - filename字段的唯一索引: 确保文件名不重复,同时加快按文件名查找的速度 - + 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 """ - if 'emoji' not in db.list_collection_names(): - db.create_collection('emoji') - db.emoji.create_index([('embedding', '2dsphere')]) - db.emoji.create_index([('filename', 1)], unique=True) + if "emoji" not in db.list_collection_names(): + db.create_collection("emoji") + db.emoji.create_index([("embedding", "2dsphere")]) + db.emoji.create_index([("filename", 1)], unique=True) def record_usage(self, emoji_id: str): """记录表情使用次数""" try: self._ensure_db() - db.emoji.update_one( - {'_id': emoji_id}, - {'$inc': {'usage_count': 1}} - ) + db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}}) except Exception as e: logger.error(f"记录表情使用失败: {str(e)}") - - async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]: + + async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str, str]]: """根据文本内容获取相关表情包 Args: text: 输入文本 Returns: Optional[str]: 表情包文件路径,如果没有找到则返回None - - + + 可不可以通过 配置文件中的指令 来自定义使用表情包的逻辑? - 我觉得可行 + 我觉得可行 """ try: @@ -118,7 +115,7 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = list(db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) + all_emojis = list(db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1})) if not all_emojis: logger.warning("数据库中没有任何表情包") @@ -137,34 +134,31 @@ class EmojiManager: # 计算所有表情包与输入文本的相似度 emoji_similarities = [ - (emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) - for emoji in all_emojis + (emoji, cosine_similarity(text_embedding, emoji.get("embedding", []))) for emoji in all_emojis ] # 按相似度降序排序 emoji_similarities.sort(key=lambda x: x[1], reverse=True) # 获取前3个最相似的表情包 - top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)] - + top_10_emojis = emoji_similarities[: 10 if len(emoji_similarities) > 10 else len(emoji_similarities)] + if not top_10_emojis: logger.warning("未找到匹配的表情包") return None # 从前3个中随机选择一个 selected_emoji, similarity = random.choice(top_10_emojis) - - if selected_emoji and 'path' in selected_emoji: + + if selected_emoji and "path" in selected_emoji: # 更新使用次数 - db.emoji.update_one( - {'_id': selected_emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) + db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}}) logger.success( - f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})") + f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})" + ) # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 - return selected_emoji['path'], "[ %s ]" % selected_emoji.get('description', '无描述') + return selected_emoji["path"], "[ %s ]" % selected_emoji.get("description", "无描述") except Exception as search_error: logger.error(f"搜索表情包失败: {str(search_error)}") @@ -176,7 +170,6 @@ class EmojiManager: logger.error(f"获取表情包失败: {str(e)}") return None - async def _get_emoji_discription(self, image_base64: str) -> str: """获取表情包的标签,使用image_manager的描述生成功能""" @@ -184,16 +177,16 @@ class EmojiManager: # 使用image_manager获取描述,去掉前后的方括号和"表情包:"前缀 description = await image_manager.get_emoji_description(image_base64) # 去掉[表情包:xxx]的格式,只保留描述内容 - description = description.strip('[]').replace('表情包:', '') + description = description.strip("[]").replace("表情包:", "") return description - + except Exception as e: logger.error(f"获取标签失败: {str(e)}") return None async def _check_emoji(self, image_base64: str, image_format: str) -> str: try: - prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' + prompt = f'这是一个表情包,请回答这个表情包是否满足"{global_config.EMOJI_CHECK_PROMPT}"的要求,是则回答是,否则回答否,不要出现任何其他内容' content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) logger.debug(f"输出描述: {content}") @@ -205,9 +198,9 @@ class EmojiManager: async def _get_kimoji_for_text(self, text: str): try: - prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' + prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出"一种什么样的感觉"中间的形容词部分。' - content, _ = await self.llm_emotion_judge.generate_response_async(prompt,temperature=1.5) + content, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=1.5) logger.info(f"输出描述: {content}") return content @@ -222,63 +215,58 @@ class EmojiManager: os.makedirs(emoji_dir, exist_ok=True) # 获取所有支持的图片文件 - files_to_process = [f for f in os.listdir(emoji_dir) if - f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))] + files_to_process = [ + f for f in os.listdir(emoji_dir) if f.lower().endswith((".jpg", ".jpeg", ".png", ".gif")) + ] for filename in files_to_process: image_path = os.path.join(emoji_dir, filename) - + # 获取图片的base64编码和哈希值 image_base64 = image_path_to_base64(image_path) if image_base64 is None: os.remove(image_path) continue - + image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 检查是否已经注册过 - existing_emoji = db['emoji'].find_one({'filename': filename}) + existing_emoji = db["emoji"].find_one({"hash": image_hash}) description = None - + if existing_emoji: # 即使表情包已存在,也检查是否需要同步到images集合 - description = existing_emoji.get('discription') + description = existing_emoji.get("discription") # 检查是否在images集合中存在 - existing_image = db.images.find_one({'hash': image_hash}) + existing_image = db.images.find_one({"hash": image_hash}) if not existing_image: # 同步到images集合 image_doc = { - 'hash': image_hash, - 'path': image_path, - 'type': 'emoji', - 'description': description, - 'timestamp': int(time.time()) + "hash": image_hash, + "path": image_path, + "type": "emoji", + "description": description, + "timestamp": int(time.time()), } - db.images.update_one( - {'hash': image_hash}, - {'$set': image_doc}, - upsert=True - ) + db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) # 保存描述到image_descriptions集合 - image_manager._save_description_to_db(image_hash, description, 'emoji') + image_manager._save_description_to_db(image_hash, description, "emoji") logger.success(f"同步已存在的表情包到images集合: {filename}") continue - + # 检查是否在images集合中已有描述 - existing_description = image_manager._get_description_from_db(image_hash, 'emoji') - + existing_description = image_manager._get_description_from_db(image_hash, "emoji") + if existing_description: description = existing_description else: # 获取表情包的描述 description = await self._get_emoji_discription(image_base64) - - if global_config.EMOJI_CHECK: check = await self._check_emoji(image_base64, image_format) - if '是' not in check: + if "是" not in check: os.remove(image_path) logger.info(f"描述: {description}") @@ -286,44 +274,39 @@ class EmojiManager: logger.info(f"其不满足过滤规则,被剔除 {check}") continue logger.info(f"check通过 {check}") - + if description is not None: embedding = await get_embedding(description) - + if description is not None: embedding = await get_embedding(description) # 准备数据库记录 emoji_record = { - 'filename': filename, - 'path': image_path, - 'embedding': embedding, - 'discription': description, - 'hash': image_hash, - 'timestamp': int(time.time()) + "filename": filename, + "path": image_path, + "embedding": embedding, + "discription": description, + "hash": image_hash, + "timestamp": int(time.time()), } - + # 保存到emoji数据库 - db['emoji'].insert_one(emoji_record) + db["emoji"].insert_one(emoji_record) logger.success(f"注册新表情包: {filename}") logger.info(f"描述: {description}") - # 保存到images数据库 image_doc = { - 'hash': image_hash, - 'path': image_path, - 'type': 'emoji', - 'description': description, - 'timestamp': int(time.time()) + "hash": image_hash, + "path": image_path, + "type": "emoji", + "description": description, + "timestamp": int(time.time()), } - db.images.update_one( - {'hash': image_hash}, - {'$set': image_doc}, - upsert=True - ) + db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) # 保存描述到image_descriptions集合 - image_manager._save_description_to_db(image_hash, description, 'emoji') + image_manager._save_description_to_db(image_hash, description, "emoji") logger.success(f"同步保存到images集合: {filename}") else: logger.warning(f"跳过表情包: {filename}") @@ -351,28 +334,35 @@ class EmojiManager: for emoji in all_emojis: try: - if 'path' not in emoji: + if "path" not in emoji: logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") - db.emoji.delete_one({'_id': emoji['_id']}) + db.emoji.delete_one({"_id": emoji["_id"]}) removed_count += 1 continue - if 'embedding' not in emoji: + if "embedding" not in emoji: logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") - db.emoji.delete_one({'_id': emoji['_id']}) + db.emoji.delete_one({"_id": emoji["_id"]}) removed_count += 1 continue # 检查文件是否存在 - if not os.path.exists(emoji['path']): + if not os.path.exists(emoji["path"]): logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 - result = db.emoji.delete_one({'_id': emoji['_id']}) + result = db.emoji.delete_one({"_id": emoji["_id"]}) if result.deleted_count > 0: logger.debug(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 else: logger.error(f"删除数据库记录失败: {emoji['_id']}") + continue + + if "hash" not in emoji: + logger.warning(f"发现缺失记录(缺少hash字段),ID: {emoji.get('_id', 'unknown')}") + hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest() + db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}}) + except Exception as item_error: logger.error(f"处理表情包记录时出错: {str(item_error)}") continue @@ -398,5 +388,3 @@ class EmojiManager: # 创建全局单例 emoji_manager = EmojiManager() - - diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 626e7cf4e..96308c50b 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -23,8 +23,8 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @dataclass class Message(MessageBase): - chat_stream: ChatStream=None - reply: Optional['Message'] = None + chat_stream: ChatStream = None + reply: Optional["Message"] = None detailed_plain_text: str = "" processed_plain_text: str = "" @@ -35,7 +35,7 @@ class Message(MessageBase): chat_stream: ChatStream, user_info: UserInfo, message_segment: Optional[Seg] = None, - reply: Optional['MessageRecv'] = None, + reply: Optional["MessageRecv"] = None, detailed_plain_text: str = "", processed_plain_text: str = "", ): @@ -45,21 +45,17 @@ class Message(MessageBase): message_id=message_id, time=time, group_info=chat_stream.group_info, - user_info=user_info + user_info=user_info, ) # 调用父类初始化 - super().__init__( - message_info=message_info, - message_segment=message_segment, - raw_message=None - ) + super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) self.chat_stream = chat_stream # 文本处理相关属性 self.processed_plain_text = processed_plain_text self.detailed_plain_text = detailed_plain_text - + # 回复消息 self.reply = reply @@ -74,41 +70,38 @@ class MessageRecv(Message): Args: message_dict: MessageCQ序列化后的字典 """ - self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {})) + self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - message_segment = message_dict.get('message_segment', {}) + message_segment = message_dict.get("message_segment", {}) - if message_segment.get('data','') == '[json]': + if message_segment.get("data", "") == "[json]": # 提取json消息中的展示信息 - pattern = r'\[CQ:json,data=(?P.+?)\]' - match = re.search(pattern, message_dict.get('raw_message','')) - raw_json = html.unescape(match.group('json_data')) + pattern = r"\[CQ:json,data=(?P.+?)\]" + match = re.search(pattern, message_dict.get("raw_message", "")) + raw_json = html.unescape(match.group("json_data")) try: json_message = json.loads(raw_json) except json.JSONDecodeError: json_message = {} - message_segment['data'] = json_message.get('prompt','') + message_segment["data"] = json_message.get("prompt", "") + + self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) + self.raw_message = message_dict.get("raw_message") - self.message_segment = Seg.from_dict(message_dict.get('message_segment', {})) - self.raw_message = message_dict.get('raw_message') - # 处理消息内容 self.processed_plain_text = "" # 初始化为空字符串 - self.detailed_plain_text = "" # 初始化为空字符串 - self.is_emoji=False - - - def update_chat_stream(self,chat_stream:ChatStream): - self.chat_stream=chat_stream - + self.detailed_plain_text = "" # 初始化为空字符串 + self.is_emoji = False + + def update_chat_stream(self, chat_stream: ChatStream): + self.chat_stream = chat_stream + async def process(self) -> None: """处理消息内容,生成纯文本和详细文本 这个方法必须在创建实例后显式调用,因为它包含异步操作。 """ - self.processed_plain_text = await self._process_message_segments( - self.message_segment - ) + self.processed_plain_text = await self._process_message_segments(self.message_segment) self.detailed_plain_text = self._generate_detailed_text() async def _process_message_segments(self, segment: Seg) -> str: @@ -157,16 +150,12 @@ class MessageRecv(Message): else: return f"[{seg.type}:{str(seg.data)}]" except Exception as e: - logger.error( - f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}" - ) + logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}") return f"[处理失败的{seg.type}消息]" def _generate_detailed_text(self) -> str: """生成详细文本,包含时间和用户信息""" - time_str = time.strftime( - "%m-%d %H:%M:%S", time.localtime(self.message_info.time) - ) + time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) user_info = self.message_info.user_info name = ( f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" @@ -174,7 +163,7 @@ class MessageRecv(Message): else f"{user_info.user_nickname}(ta的id:{user_info.user_id})" ) return f"[{time_str}] {name}: {self.processed_plain_text}\n" - + @dataclass class MessageProcessBase(Message): @@ -257,16 +246,12 @@ class MessageProcessBase(Message): else: return f"[{seg.type}:{str(seg.data)}]" except Exception as e: - logger.error( - f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}" - ) + logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}") return f"[处理失败的{seg.type}消息]" def _generate_detailed_text(self) -> str: """生成详细文本,包含时间和用户信息""" - time_str = time.strftime( - "%m-%d %H:%M:%S", time.localtime(self.message_info.time) - ) + time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) user_info = self.message_info.user_info name = ( f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" @@ -330,10 +315,11 @@ class MessageSending(MessageProcessBase): self.is_head = is_head self.is_emoji = is_emoji - def set_reply(self, reply: Optional["MessageRecv"]) -> None: + def set_reply(self, reply: Optional["MessageRecv"] = None) -> None: """设置回复消息""" if reply: self.reply = reply + if self.reply: self.reply_to_message_id = self.reply.message_info.message_id self.message_segment = Seg( type="seglist", @@ -346,9 +332,7 @@ class MessageSending(MessageProcessBase): async def process(self) -> None: """处理消息内容,生成纯文本和详细文本""" if self.message_segment: - self.processed_plain_text = await self._process_message_segments( - self.message_segment - ) + self.processed_plain_text = await self._process_message_segments(self.message_segment) self.detailed_plain_text = self._generate_detailed_text() @classmethod @@ -377,10 +361,7 @@ class MessageSending(MessageProcessBase): def is_private_message(self) -> bool: """判断是否为私聊消息""" - return ( - self.message_info.group_info is None - or self.message_info.group_info.group_id is None - ) + return self.message_info.group_info is None or self.message_info.group_info.group_id is None @dataclass diff --git a/src/plugins/chat/message_base.py b/src/plugins/chat/message_base.py index ae7ec3872..80b8b6618 100644 --- a/src/plugins/chat/message_base.py +++ b/src/plugins/chat/message_base.py @@ -65,6 +65,8 @@ class GroupInfo: Returns: GroupInfo: 新的实例 """ + if data.get('group_id') is None: + return None return cls( platform=data.get('platform'), group_id=data.get('group_id'), @@ -129,8 +131,8 @@ class BaseMessageInfo: Returns: BaseMessageInfo: 新的实例 """ - group_info = GroupInfo(**data.get('group_info', {})) - user_info = UserInfo(**data.get('user_info', {})) + group_info = GroupInfo.from_dict(data.get('group_info', {})) + user_info = UserInfo.from_dict(data.get('user_info', {})) return cls( platform=data.get('platform'), message_id=data.get('message_id'), @@ -173,7 +175,7 @@ class MessageBase: Returns: MessageBase: 新的实例 """ - message_info = BaseMessageInfo(**data.get('message_info', {})) + message_info = BaseMessageInfo.from_dict(data.get('message_info', {})) message_segment = Seg(**data.get('message_segment', {})) raw_message = data.get('raw_message',None) return cls( diff --git a/src/plugins/chat/message_cq.py b/src/plugins/chat/message_cq.py index 59d67a455..4c46d3bf2 100644 --- a/src/plugins/chat/message_cq.py +++ b/src/plugins/chat/message_cq.py @@ -8,48 +8,40 @@ from .cq_code import cq_code_tool from .utils_cq import parse_cq_code from .utils_user import get_groupname from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase + # 禁用SSL警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) -#这个类是消息数据类,用于存储和管理消息数据。 -#它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 -#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 +# 这个类是消息数据类,用于存储和管理消息数据。 +# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 +# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 + @dataclass class MessageCQ(MessageBase): """QQ消息基类,继承自MessageBase - + 最小必要参数: - message_id: 消息ID - user_id: 发送者/接收者ID - platform: 平台标识(默认为"qq") """ + def __init__( - self, - message_id: int, - user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - platform: str = "qq" + self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq" ): # 构造基础消息信息 message_info = BaseMessageInfo( - platform=platform, - message_id=message_id, - time=int(time.time()), - group_info=group_info, - user_info=user_info + platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info ) # 调用父类初始化,message_segment 由子类设置 - super().__init__( - message_info=message_info, - message_segment=None, - raw_message=None - ) + super().__init__(message_info=message_info, message_segment=None, raw_message=None) + @dataclass class MessageRecvCQ(MessageCQ): """QQ接收消息类,用于解析raw_message到Seg对象""" - + def __init__( self, message_id: int, @@ -61,14 +53,14 @@ class MessageRecvCQ(MessageCQ): ): # 调用父类初始化 super().__init__(message_id, user_info, group_info, platform) - + # 私聊消息不携带group_info if group_info is None: pass elif group_info.group_name is None: group_info.group_name = get_groupname(group_info.group_id) - + # 解析消息段 self.message_segment = self._parse_message(raw_message, reply_message) self.raw_message = raw_message @@ -77,10 +69,10 @@ class MessageRecvCQ(MessageCQ): """解析消息内容为Seg对象""" cq_code_dict_list = [] segments = [] - + start = 0 while True: - cq_start = message.find('[CQ:', start) + cq_start = message.find("[CQ:", start) if cq_start == -1: if start < len(message): text = message[start:].strip() @@ -93,81 +85,80 @@ class MessageRecvCQ(MessageCQ): if text: cq_code_dict_list.append(parse_cq_code(text)) - cq_end = message.find(']', cq_start) + cq_end = message.find("]", cq_start) if cq_end == -1: text = message[cq_start:].strip() if text: cq_code_dict_list.append(parse_cq_code(text)) break - cq_code = message[cq_start:cq_end + 1] + cq_code = message[cq_start : cq_end + 1] cq_code_dict_list.append(parse_cq_code(cq_code)) start = cq_end + 1 # 转换CQ码为Seg对象 for code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message) + message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) if message_obj.translated_segments: segments.append(message_obj.translated_segments) # 如果只有一个segment,直接返回 if len(segments) == 1: return segments[0] - + # 否则返回seglist类型的Seg - return Seg(type='seglist', data=segments) + return Seg(type="seglist", data=segments) def to_dict(self) -> Dict: """转换为字典格式,包含所有必要信息""" base_dict = super().to_dict() return base_dict + @dataclass class MessageSendCQ(MessageCQ): """QQ发送消息类,用于将Seg对象转换为raw_message""" - - def __init__( - self, - data: Dict - ): + + def __init__(self, data: Dict): # 调用父类初始化 - message_info = BaseMessageInfo.from_dict(data.get('message_info', {})) - message_segment = Seg.from_dict(data.get('message_segment', {})) + message_info = BaseMessageInfo.from_dict(data.get("message_info", {})) + message_segment = Seg.from_dict(data.get("message_segment", {})) super().__init__( - message_info.message_id, - message_info.user_info, - message_info.group_info if message_info.group_info else None, - message_info.platform - ) - + message_info.message_id, + message_info.user_info, + message_info.group_info if message_info.group_info else None, + message_info.platform, + ) + self.message_segment = message_segment self.raw_message = self._generate_raw_message() - def _generate_raw_message(self, ) -> str: + def _generate_raw_message( + self, + ) -> str: """将Seg对象转换为raw_message""" segments = [] # 处理消息段 - if self.message_segment.type == 'seglist': + if self.message_segment.type == "seglist": for seg in self.message_segment.data: segments.append(self._seg_to_cq_code(seg)) else: segments.append(self._seg_to_cq_code(self.message_segment)) - return ''.join(segments) + return "".join(segments) def _seg_to_cq_code(self, seg: Seg) -> str: """将单个Seg对象转换为CQ码字符串""" - if seg.type == 'text': + if seg.type == "text": return str(seg.data) - elif seg.type == 'image': + elif seg.type == "image": return cq_code_tool.create_image_cq_base64(seg.data) - elif seg.type == 'emoji': + elif seg.type == "emoji": return cq_code_tool.create_emoji_cq_base64(seg.data) - elif seg.type == 'at': + elif seg.type == "at": return f"[CQ:at,qq={seg.data}]" - elif seg.type == 'reply': + elif seg.type == "reply": return cq_code_tool.create_reply_cq(int(seg.data)) else: return f"[{seg.data}]" - diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 2154280de..cc3a6ca3d 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -13,19 +13,21 @@ from nonebot import get_driver from ...common.database import db from ..chat.config import global_config from ..models.utils_model import LLM_request + driver = get_driver() config = driver.config + class ImageManager: _instance = None IMAGE_DIR = "data" # 图像存储根目录 - + def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance - + def __init__(self): if not self._initialized: self._ensure_image_collection() @@ -33,124 +35,55 @@ class ImageManager: self._ensure_image_dir() self._initialized = True self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) - + def _ensure_image_dir(self): """确保图像存储目录存在""" os.makedirs(self.IMAGE_DIR, exist_ok=True) - + def _ensure_image_collection(self): """确保images集合存在并创建索引""" - if 'images' not in db.list_collection_names(): - db.create_collection('images') + if "images" not in db.list_collection_names(): + db.create_collection("images") # 创建索引 - db.images.create_index([('hash', 1)], unique=True) - db.images.create_index([('url', 1)]) - db.images.create_index([('path', 1)]) + db.images.create_index([("hash", 1)], unique=True) + db.images.create_index([("url", 1)]) + db.images.create_index([("path", 1)]) def _ensure_description_collection(self): """确保image_descriptions集合存在并创建索引""" - if 'image_descriptions' not in db.list_collection_names(): - db.create_collection('image_descriptions') + if "image_descriptions" not in db.list_collection_names(): + db.create_collection("image_descriptions") # 创建索引 - db.image_descriptions.create_index([('hash', 1)], unique=True) - db.image_descriptions.create_index([('type', 1)]) + db.image_descriptions.create_index([("hash", 1)], unique=True) + db.image_descriptions.create_index([("type", 1)]) def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 - + Args: image_hash: 图片哈希值 description_type: 描述类型 ('emoji' 或 'image') - + Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result= db.image_descriptions.find_one({ - 'hash': image_hash, - 'type': description_type - }) - return result['description'] if result else None + result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type}) + return result["description"] if result else None def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: """保存图片描述到数据库 - + Args: image_hash: 图片哈希值 description: 描述文本 description_type: 描述类型 ('emoji' 或 'image') """ db.image_descriptions.update_one( - {'hash': image_hash, 'type': description_type}, - { - '$set': { - 'description': description, - 'timestamp': int(time.time()) - } - }, - upsert=True + {"hash": image_hash, "type": description_type}, + {"$set": {"description": description, "timestamp": int(time.time())}}, + upsert=True, ) - async def save_image(self, - image_data: Union[str, bytes], - url: str = None, - description: str = None, - is_base64: bool = False) -> Optional[str]: - """保存图像 - Args: - image_data: 图像数据(base64字符串或字节) - url: 图像URL - description: 图像描述 - is_base64: image_data是否为base64格式 - Returns: - str: 保存后的文件路径,失败返回None - """ - try: - # 转换为字节格式 - if is_base64: - if isinstance(image_data, str): - image_bytes = base64.b64decode(image_data) - else: - return None - else: - if isinstance(image_data, bytes): - image_bytes = image_data - else: - return None - - # 计算哈希值 - image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() - - # 查重 - existing = db.images.find_one({'hash': image_hash}) - if existing: - return existing['path'] - - # 生成文件名和路径 - timestamp = int(time.time()) - filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - file_path = os.path.join(self.IMAGE_DIR, filename) - - # 保存文件 - with open(file_path, "wb") as f: - f.write(image_bytes) - - # 保存到数据库 - image_doc = { - 'hash': image_hash, - 'path': file_path, - 'url': url, - 'description': description, - 'timestamp': timestamp - } - db.images.insert_one(image_doc) - - return file_path - - except Exception as e: - logger.error(f"保存图像失败: {str(e)}") - return None - async def get_image_by_url(self, url: str) -> Optional[str]: """根据URL获取图像路径(带查重) Args: @@ -160,10 +93,10 @@ class ImageManager: """ try: # 先查找是否已存在 - existing = db.images.find_one({'url': url}) + existing = db.images.find_one({"url": url}) if existing: - return existing['path'] - + return existing["path"] + # 下载图像 async with aiohttp.ClientSession() as session: async with session.get(url) as resp: @@ -171,68 +104,11 @@ class ImageManager: image_bytes = await resp.read() return await self.save_image(image_bytes, url=url) return None - + except Exception as e: logger.error(f"获取图像失败: {str(e)}") return None - - async def get_base64_by_url(self, url: str) -> Optional[str]: - """根据URL获取base64(带查重) - Args: - url: 图像URL - Returns: - str: base64字符串,失败返回None - """ - try: - image_path = await self.get_image_by_url(url) - if not image_path: - return None - - with open(image_path, 'rb') as f: - image_bytes = f.read() - return base64.b64encode(image_bytes).decode('utf-8') - - except Exception as e: - logger.error(f"获取base64失败: {str(e)}") - return None - - - def check_url_exists(self, url: str) -> bool: - """检查URL是否已存在 - Args: - url: 图像URL - Returns: - bool: 是否存在 - """ - return db.images.find_one({'url': url}) is not None - - def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: - """检查图像是否已存在 - Args: - image_data: 图像数据(base64或字节) - is_base64: 是否为base64格式 - Returns: - bool: 是否存在 - """ - try: - if is_base64: - if isinstance(image_data, str): - image_bytes = base64.b64decode(image_data) - else: - return False - else: - if isinstance(image_data, bytes): - image_bytes = image_data - else: - return False - - image_hash = hashlib.md5(image_bytes).hexdigest() - return db.images.find_one({'hash': image_hash}) is not None - - except Exception as e: - logger.error(f"检查哈希失败: {str(e)}") - return False - + async def get_emoji_description(self, image_base64: str) -> str: """获取表情包描述,带查重和保存功能""" try: @@ -242,7 +118,7 @@ class ImageManager: image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 查询缓存的描述 - cached_description = self._get_description_from_db(image_hash, 'emoji') + cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description: logger.info(f"缓存表情包描述: {cached_description}") return f"[表情包:{cached_description}]" @@ -250,39 +126,42 @@ class ImageManager: # 调用AI获取描述 prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感" description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) - + + cached_description = self._get_description_from_db(image_hash, "emoji") + if cached_description: + logger.warning(f"虽然生成了描述,但找到缓存表情包描述: {cached_description}") + return f"[表情包:{cached_description}]" + # 根据配置决定是否保存图片 if global_config.EMOJI_SAVE: # 生成文件名和路径 timestamp = int(time.time()) filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename) - + if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")): + os.makedirs(os.path.join(self.IMAGE_DIR, "emoji")) + file_path = os.path.join(self.IMAGE_DIR, "emoji", filename) + try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - + # 保存到数据库 image_doc = { - 'hash': image_hash, - 'path': file_path, - 'type': 'emoji', - 'description': description, - 'timestamp': timestamp + "hash": image_hash, + "path": file_path, + "type": "emoji", + "description": description, + "timestamp": timestamp, } - db.images.update_one( - {'hash': image_hash}, - {'$set': image_doc}, - upsert=True - ) + db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) logger.success(f"保存表情包: {file_path}") except Exception as e: logger.error(f"保存表情包文件失败: {str(e)}") - + # 保存描述到数据库 - self._save_description_to_db(image_hash, description, 'emoji') - + self._save_description_to_db(image_hash, description, "emoji") + return f"[表情包:{description}]" except Exception as e: logger.error(f"获取表情包描述失败: {str(e)}") @@ -298,60 +177,63 @@ class ImageManager: image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 查询缓存的描述 - cached_description = self._get_description_from_db(image_hash, 'image') + cached_description = self._get_description_from_db(image_hash, "image") if cached_description: print("图片描述缓存中") return f"[图片:{cached_description}]" # 调用AI获取描述 - prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" + prompt = ( + "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" + ) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) - + cached_description = self._get_description_from_db(image_hash, "emoji") + if cached_description: + logger.info(f"缓存图片描述: {cached_description}") + return f"[图片:{cached_description}]" + print(f"描述是{description}") - + if description is None: logger.warning("AI未能生成图片描述") return "[图片]" - + # 根据配置决定是否保存图片 if global_config.EMOJI_SAVE: # 生成文件名和路径 timestamp = int(time.time()) filename = f"{timestamp}_{image_hash[:8]}.{image_format}" - file_path = os.path.join(self.IMAGE_DIR,'image', filename) - + if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")): + os.makedirs(os.path.join(self.IMAGE_DIR, "image")) + file_path = os.path.join(self.IMAGE_DIR, "image", filename) + try: # 保存文件 with open(file_path, "wb") as f: f.write(image_bytes) - + # 保存到数据库 image_doc = { - 'hash': image_hash, - 'path': file_path, - 'type': 'image', - 'description': description, - 'timestamp': timestamp + "hash": image_hash, + "path": file_path, + "type": "image", + "description": description, + "timestamp": timestamp, } - db.images.update_one( - {'hash': image_hash}, - {'$set': image_doc}, - upsert=True - ) + db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True) logger.success(f"保存图片: {file_path}") except Exception as e: logger.error(f"保存图片文件失败: {str(e)}") - + # 保存描述到数据库 - self._save_description_to_db(image_hash, description, 'image') - + self._save_description_to_db(image_hash, description, "image") + return f"[图片:{description}]" except Exception as e: logger.error(f"获取图片描述失败: {str(e)}") return "[图片]" - # 创建全局单例 image_manager = ImageManager() @@ -364,9 +246,9 @@ def image_path_to_base64(image_path: str) -> str: str: base64编码的图片数据 """ try: - with open(image_path, 'rb') as f: + with open(image_path, "rb") as f: image_data = f.read() - return base64.b64encode(image_data).decode('utf-8') + return base64.b64encode(image_data).decode("utf-8") except Exception as e: logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}") - return None \ No newline at end of file + return None diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index afe4baeb5..0f5bb335c 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -132,7 +132,7 @@ class LLM_request: # 常见Error Code Mapping error_code_mapping = { 400: "参数不正确", - 401: "API key 错误,认证失败", + 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env.prod中的配置是否正确哦~", 402: "账号余额不足", 403: "需要实名,或余额不足", 404: "Not Found", diff --git a/template.env b/template.env index d2a763112..322776ce7 100644 --- a/template.env +++ b/template.env @@ -23,7 +23,7 @@ CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 -#定义你要用的api的base_url +#定义你要用的api的key(需要去对应网站申请哦) DEEP_SEEK_KEY= CHAT_ANY_WHERE_KEY= SILICONFLOW_KEY=