diff --git a/bot.py b/bot.py index 30714e846..bd28e6cee 100644 --- a/bot.py +++ b/bot.py @@ -4,15 +4,11 @@ import os import shutil import sys from pathlib import Path - -import nonebot import time - -import uvicorn -from dotenv import load_dotenv -from nonebot.adapters.onebot.v11 import Adapter import platform +from dotenv import load_dotenv from src.common.logger import get_module_logger +from src.main import MainSystem logger = get_module_logger("main_bot") @@ -134,11 +130,7 @@ def scan_provider(env_config: dict): async def graceful_shutdown(): try: - global uvicorn_server - if uvicorn_server: - uvicorn_server.force_exit = True # 强制退出 - await uvicorn_server.shutdown() - + logger.info("正在优雅关闭麦麦...") tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] for task in tasks: task.cancel() @@ -148,22 +140,6 @@ async def graceful_shutdown(): logger.error(f"麦麦关闭失败: {e}") -async def uvicorn_main(): - global uvicorn_server - config = uvicorn.Config( - app="__main__:app", - host=os.getenv("HOST", "127.0.0.1"), - port=int(os.getenv("PORT", 8080)), - reload=os.getenv("ENVIRONMENT") == "dev", - timeout_graceful_shutdown=5, - log_config=None, - access_log=False, - ) - server = uvicorn.Server(config) - uvicorn_server = server - await server.serve() - - def check_eula(): eula_confirm_file = Path("eula.confirmed") privacy_confirm_file = Path("privacy.confirmed") @@ -245,7 +221,6 @@ def check_eula(): def raw_main(): # 利用 TZ 环境变量设定程序工作的时区 - # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 if platform.system().lower() != "windows": time.tzset() @@ -256,40 +231,26 @@ def raw_main(): init_env() load_env() - # load_logger() - env_config = {key: os.getenv(key) for key in os.environ} scan_provider(env_config) - # 设置基础配置 - base_config = { - "websocket_port": int(env_config.get("PORT", 8080)), - "host": env_config.get("HOST", "127.0.0.1"), - "log_level": "INFO", - } - - # 合并配置 - nonebot.init(**base_config, **env_config) - - # 注册适配器 - global driver - driver = nonebot.get_driver() - driver.register_adapter(Adapter) - - # 加载插件 - nonebot.load_plugins("src/plugins") + # 返回MainSystem实例 + return MainSystem() if __name__ == "__main__": try: - raw_main() + # 获取MainSystem实例 + main_system = raw_main() - app = nonebot.get_asgi() + # 创建事件循环 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - loop.run_until_complete(uvicorn_main()) + # 执行初始化和任务调度 + loop.run_until_complete(main_system.initialize()) + loop.run_until_complete(main_system.schedule_tasks()) except KeyboardInterrupt: logger.warning("收到中断信号,正在优雅关闭...") loop.run_until_complete(graceful_shutdown()) diff --git a/src/main.py b/src/main.py index c32800dcc..fa8100d85 100644 --- a/src/main.py +++ b/src/main.py @@ -1,21 +1,19 @@ import asyncio import time from datetime import datetime - -from plugins.utils.statistic import LLMStatistics -from plugins.moods.moods import MoodManager -from plugins.schedule.schedule_generator import bot_schedule -from plugins.chat.emoji_manager import emoji_manager -from plugins.chat.relationship_manager import relationship_manager -from plugins.willing.willing_manager import willing_manager -from plugins.chat.chat_stream import chat_manager -from plugins.memory_system.memory import hippocampus -from plugins.chat.message_sender import message_manager -from plugins.chat.storage import MessageStorage -from plugins.chat.config import global_config -from common.logger import get_module_logger -from fastapi import FastAPI -from plugins.chat.api import app as api_app +from .plugins.utils.statistic import LLMStatistics +from .plugins.moods.moods import MoodManager +from .plugins.schedule.schedule_generator import bot_schedule +from .plugins.chat.emoji_manager import emoji_manager +from .plugins.chat.relationship_manager import relationship_manager +from .plugins.willing.willing_manager import willing_manager +from .plugins.chat.chat_stream import chat_manager +from .plugins.memory_system.memory import hippocampus +from .plugins.chat.message_sender import message_manager +from .plugins.chat.storage import MessageStorage +from .plugins.chat.config import global_config +from .plugins.chat.bot import chat_bot +from .common.logger import get_module_logger logger = get_module_logger("main") @@ -25,13 +23,29 @@ class MainSystem: self.llm_stats = LLMStatistics("llm_statistics.txt") self.mood_manager = MoodManager.get_instance() self._message_manager_started = False - self.app = FastAPI() - self.app.mount("/chat", api_app) + + # 使用消息API替代直接的FastAPI实例 + from .plugins.message import global_api + + self.app = global_api async def initialize(self): """初始化系统组件""" logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") + # 启动API服务器(改为异步启动) + self.api_task = asyncio.create_task(self.app.run()) + + # 其他初始化任务 + await asyncio.gather( + self._init_components(), # 将原有的初始化代码移到这个新方法中 + # api_task, + ) + + logger.success("系统初始化完成") + + async def _init_components(self): + """初始化其他组件""" # 启动LLM统计 self.llm_stats.start() logger.success("LLM统计功能启动成功") @@ -64,10 +78,7 @@ class MainSystem: bot_schedule.print_schedule() # 启动FastAPI服务器 - import uvicorn - - uvicorn.run(self.app, host="0.0.0.0", port=18000) - logger.success("API服务器启动成功") + self.app.register_message_handler(chat_bot.message_process) async def schedule_tasks(self): """调度定时任务""" @@ -86,6 +97,7 @@ class MainSystem: async def build_memory_task(self): """记忆构建任务""" while True: + logger.info("正在进行记忆构建") await hippocampus.operation_build_memory() await asyncio.sleep(global_config.build_memory_interval) @@ -100,6 +112,7 @@ class MainSystem: async def merge_memory_task(self): """记忆整合任务""" while True: + logger.info("正在进行记忆整合") await asyncio.sleep(global_config.build_memory_interval + 10) async def print_mood_task(self): @@ -130,8 +143,9 @@ class MainSystem: async def main(): """主函数""" system = MainSystem() - await system.initialize() - await system.schedule_tasks() + await asyncio.gather(system.initialize(), system.schedule_tasks(), system.api_task) + # await system.initialize() + # await system.schedule_tasks() if __name__ == "__main__": diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 000000000..56db4dfa3 --- /dev/null +++ b/src/plugins/__init__.py @@ -0,0 +1,23 @@ +""" +MaiMBot插件系统 +包含聊天、情绪、记忆、日程等功能模块 +""" + +from .chat.chat_stream import chat_manager +from .chat.emoji_manager import emoji_manager +from .chat.relationship_manager import relationship_manager +from .moods.moods import MoodManager +from .willing.willing_manager import willing_manager +from .memory_system.memory import hippocampus +from .schedule.schedule_generator import bot_schedule + +# 导出主要组件供外部使用 +__all__ = [ + "chat_manager", + "emoji_manager", + "relationship_manager", + "MoodManager", + "willing_manager", + "hippocampus", + "bot_schedule", +] diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 56ea9408c..e9c3008b3 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -1,154 +1,15 @@ -import asyncio -import time - -from nonebot import get_driver, on_message, on_notice, require -from nonebot.adapters.onebot.v11 import Bot, MessageEvent, NoticeEvent -from nonebot.typing import T_State - -from ..moods.moods import MoodManager # 导入情绪管理器 -from ..schedule.schedule_generator import bot_schedule -from ..utils.statistic import LLMStatistics -from .bot import chat_bot -from .config import global_config from .emoji_manager import emoji_manager from .relationship_manager import relationship_manager -from ..willing.willing_manager import willing_manager from .chat_stream import chat_manager -from ..memory_system.memory import hippocampus -from .message_sender import message_manager, message_sender +from .message_sender import message_manager from .storage import MessageStorage -from src.common.logger import get_module_logger +from .config import global_config -logger = get_module_logger("chat_init") - -# 创建LLM统计实例 -llm_stats = LLMStatistics("llm_statistics.txt") - -# 添加标志变量 -_message_manager_started = False - -# 获取驱动器 -driver = get_driver() -config = driver.config - -# 初始化表情管理器 -emoji_manager.initialize() - -logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") -# 注册消息处理器 -msg_in = on_message(priority=5) -# 注册和bot相关的通知处理器 -notice_matcher = on_notice(priority=1) -# 创建定时任务 -scheduler = require("nonebot_plugin_apscheduler").scheduler - - -@driver.on_startup -async def start_background_tasks(): - """启动后台任务""" - # 启动LLM统计 - llm_stats.start() - logger.success("LLM统计功能启动成功") - - # 初始化并启动情绪管理器 - mood_manager = MoodManager.get_instance() - mood_manager.start_mood_update(update_interval=global_config.mood_update_interval) - logger.success("情绪管理器启动成功") - - # 只启动表情包管理任务 - asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL)) - await bot_schedule.initialize() - bot_schedule.print_schedule() - - -@driver.on_startup -async def init_relationships(): - """在 NoneBot2 启动时初始化关系管理器""" - logger.debug("正在加载用户关系数据...") - await relationship_manager.load_all_relationships() - asyncio.create_task(relationship_manager._start_relationship_manager()) - - -@driver.on_bot_connect -async def _(bot: Bot): - """Bot连接成功时的处理""" - global _message_manager_started - logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------") - await willing_manager.ensure_started() - - message_sender.set_bot(bot) - logger.success("-----------消息发送器已启动!-----------") - - if not _message_manager_started: - asyncio.create_task(message_manager.start_processor()) - _message_manager_started = True - logger.success("-----------消息处理器已启动!-----------") - - asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) - logger.success("-----------开始偷表情包!-----------") - asyncio.create_task(chat_manager._initialize()) - asyncio.create_task(chat_manager._auto_save_task()) - - -@msg_in.handle() -async def _(bot: Bot, event: MessageEvent, state: T_State): - # 处理合并转发消息 - if "forward" in event.message: - await chat_bot.handle_forward_message(event, bot) - else: - await chat_bot.handle_message(event, bot) - - -@notice_matcher.handle() -async def _(bot: Bot, event: NoticeEvent, state: T_State): - logger.debug(f"收到通知:{event}") - await chat_bot.handle_notice(event, bot) - - -# 添加build_memory定时任务 -@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") -async def build_memory_task(): - """每build_memory_interval秒执行一次记忆构建""" - await hippocampus.operation_build_memory() - - -@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") -async def forget_memory_task(): - """每30秒执行一次记忆构建""" - print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...") - await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage) - print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成") - - -@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory") -async def merge_memory_task(): - """每30秒执行一次记忆构建""" - # print("\033[1;32m[记忆整合]\033[0m 开始整合") - # await hippocampus.operation_merge_memory(percentage=0.1) - # print("\033[1;32m[记忆整合]\033[0m 记忆整合完成") - - -@scheduler.scheduled_job("interval", seconds=30, id="print_mood") -async def print_mood_task(): - """每30秒打印一次情绪状态""" - mood_manager = MoodManager.get_instance() - mood_manager.print_mood_status() - - -@scheduler.scheduled_job("interval", seconds=7200, id="generate_schedule") -async def generate_schedule_task(): - """每2小时尝试生成一次日程""" - logger.debug("尝试生成日程") - await bot_schedule.initialize() - if not bot_schedule.enable_output: - bot_schedule.print_schedule() - - -@scheduler.scheduled_job("interval", seconds=3600, id="remove_recalled_message") -async def remove_recalled_message() -> None: - """删除撤回消息""" - try: - storage = MessageStorage() - await storage.remove_recalled_message(time.time()) - except Exception: - logger.exception("删除撤回消息失败") +__all__ = [ + "emoji_manager", + "relationship_manager", + "chat_manager", + "message_manager", + "MessageStorage", + "global_config", +] diff --git a/src/plugins/chat/api.py b/src/plugins/chat/api.py deleted file mode 100644 index 14a646832..000000000 --- a/src/plugins/chat/api.py +++ /dev/null @@ -1,54 +0,0 @@ -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel -from typing import Optional, Dict, Any -from .bot import chat_bot -from .message_cq import MessageRecvCQ -from .message_base import UserInfo, GroupInfo -from src.common.logger import get_module_logger - -logger = get_module_logger("chat_api") - -app = FastAPI() - - -class MessageRequest(BaseModel): - message_id: int - user_info: Dict[str, Any] - raw_message: str - group_info: Optional[Dict[str, Any]] = None - reply_message: Optional[Dict[str, Any]] = None - platform: str = "api" - - -@app.post("/api/message") -async def handle_message(message: MessageRequest): - try: - user_info = UserInfo( - user_id=message.user_info["user_id"], - user_nickname=message.user_info["user_nickname"], - user_cardname=message.user_info.get("user_cardname"), - platform=message.platform, - ) - - group_info = None - if message.group_info: - group_info = GroupInfo( - group_id=message.group_info["group_id"], - group_name=message.group_info.get("group_name"), - platform=message.platform, - ) - - message_cq = MessageRecvCQ( - message_id=message.message_id, - user_info=user_info, - raw_message=message.raw_message, - group_info=group_info, - reply_message=message.reply_message, - platform=message.platform, - ) - - await chat_bot.message_process(message_cq) - return {"status": "success"} - except Exception as e: - logger.exception("API处理消息时出错") - raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index aebe1e7db..905ed1cdf 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -1,16 +1,7 @@ import re import time from random import random -from nonebot.adapters.onebot.v11 import ( - Bot, - MessageEvent, - PrivateMessageEvent, - GroupMessageEvent, - NoticeEvent, - PokeNotifyEvent, - GroupRecallNoticeEvent, - FriendRecallNoticeEvent, -) +import json from ..memory_system.memory import hippocampus from ..moods.moods import MoodManager # 导入情绪管理器 @@ -18,9 +9,7 @@ from .config import global_config from .emoji_manager import emoji_manager # 导入表情包管理器 from .llm_generator import ResponseGenerator from .message import MessageSending, MessageRecv, MessageThinking, MessageSet -from .message_cq import ( - MessageRecvCQ, -) + from .chat_stream import chat_manager from .message_sender import message_manager # 导入新的消息管理器 @@ -30,7 +19,7 @@ from .utils import is_mentioned_bot_in_message from .utils_image import image_path_to_base64 from .utils_user import get_user_nickname, get_user_cardname from ..willing.willing_manager import willing_manager # 导入意愿管理器 -from .message_base import UserInfo, GroupInfo, Seg +from ..message import UserInfo, GroupInfo, Seg from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig @@ -62,7 +51,7 @@ class ChatBot: if not self._started: self._started = True - async def message_process(self, message_cq: MessageRecvCQ) -> None: + async def message_process(self, message_data: str) -> None: """处理转化后的统一格式消息 1. 过滤消息 2. 记忆激活 @@ -71,12 +60,11 @@ class ChatBot: 5. 更新关系 6. 更新情绪 """ - await message_cq.initialize() - message_json = message_cq.to_dict() + # message_json = json.loads(message_data) # 哦我嘞个json # 进入maimbot - message = MessageRecv(message_json) + message = MessageRecv(message_data) groupinfo = message.message_info.group_info userinfo = message.message_info.user_info messageinfo = message.message_info @@ -146,7 +134,7 @@ class ChatBot: response = None # 开始组织语言 - if random() < reply_probability: + if random() < reply_probability + 100: bot_user_info = UserInfo( user_id=global_config.BOT_QQ, user_nickname=global_config.BOT_NICKNAME, @@ -278,235 +266,6 @@ class ChatBot: # chat_stream=chat # ) - async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None: - """处理收到的通知""" - if isinstance(event, PokeNotifyEvent): - # 戳一戳 通知 - # 不处理其他人的戳戳 - if not event.is_tome(): - return - - # 用户屏蔽,不区分私聊/群聊 - if event.user_id in global_config.ban_user_id: - return - - # 白名单模式 - if event.group_id: - if event.group_id not in global_config.talk_allowed_groups: - return - - raw_message = f"[戳了戳]{global_config.BOT_NICKNAME}" # 默认类型 - if info := event.model_extra["raw_info"]: - poke_type = info[2].get("txt", "戳了戳") # 戳戳类型,例如“拍一拍”、“揉一揉”、“捏一捏” - custom_poke_message = info[4].get("txt", "") # 自定义戳戳消息,若不存在会为空字符串 - raw_message = f"[{poke_type}]{global_config.BOT_NICKNAME}{custom_poke_message}" - - raw_message += "(这是一个类似摸摸头的友善行为,而不是恶意行为,请不要作出攻击发言)" - - user_info = UserInfo( - user_id=event.user_id, - user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"], - user_cardname=None, - platform="qq", - ) - - if event.group_id: - group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq") - else: - group_info = None - - message_cq = MessageRecvCQ( - message_id=0, - user_info=user_info, - raw_message=str(raw_message), - group_info=group_info, - reply_message=None, - platform="qq", - ) - - await self.message_process(message_cq) - - elif isinstance(event, GroupRecallNoticeEvent) or isinstance(event, FriendRecallNoticeEvent): - user_info = UserInfo( - user_id=event.user_id, - user_nickname=get_user_nickname(event.user_id) or None, - user_cardname=get_user_cardname(event.user_id) or None, - platform="qq", - ) - - if isinstance(event, GroupRecallNoticeEvent): - group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq") - else: - group_info = None - - chat = await chat_manager.get_or_create_stream( - platform=user_info.platform, user_info=user_info, group_info=group_info - ) - - await self.storage.store_recalled_message(event.message_id, time.time(), chat) - - async def handle_message(self, event: MessageEvent, bot: Bot) -> None: - """处理收到的消息""" - - self.bot = bot # 更新 bot 实例 - - # 用户屏蔽,不区分私聊/群聊 - if event.user_id in global_config.ban_user_id: - return - - if ( - event.reply - and hasattr(event.reply, "sender") - and hasattr(event.reply.sender, "user_id") - and event.reply.sender.user_id in global_config.ban_user_id - ): - logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息") - return - # 处理私聊消息 - if isinstance(event, PrivateMessageEvent): - if not global_config.enable_friend_chat: # 私聊过滤 - return - else: - try: - user_info = UserInfo( - user_id=event.user_id, - user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"], - user_cardname=None, - platform="qq", - ) - except Exception as e: - logger.error(f"获取陌生人信息失败: {e}") - return - logger.debug(user_info) - - # group_info = GroupInfo(group_id=0, group_name="私聊", platform="qq") - group_info = None - - # 处理群聊消息 - else: - # 白名单设定由nontbot侧完成 - if event.group_id: - if event.group_id not in global_config.talk_allowed_groups: - return - - user_info = UserInfo( - user_id=event.user_id, - user_nickname=event.sender.nickname, - user_cardname=event.sender.card or None, - platform="qq", - ) - - group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq") - - # group_info = await bot.get_group_info(group_id=event.group_id) - # sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) - - message_cq = MessageRecvCQ( - message_id=event.message_id, - user_info=user_info, - raw_message=str(event.original_message), - group_info=group_info, - reply_message=event.reply, - platform="qq", - ) - - await self.message_process(message_cq) - - async def handle_forward_message(self, event: MessageEvent, bot: Bot) -> None: - """专用于处理合并转发的消息处理器""" - - # 用户屏蔽,不区分私聊/群聊 - if event.user_id in global_config.ban_user_id: - return - - if isinstance(event, GroupMessageEvent): - if event.group_id: - if event.group_id not in global_config.talk_allowed_groups: - return - - # 获取合并转发消息的详细信息 - forward_info = await bot.get_forward_msg(message_id=event.message_id) - messages = forward_info["messages"] - - # 构建合并转发消息的文本表示 - processed_messages = [] - for node in messages: - # 提取发送者昵称 - nickname = node["sender"].get("nickname", "未知用户") - - # 递归处理消息内容 - message_content = await self.process_message_segments(node["message"], layer=0) - - # 拼接为【昵称】+ 内容 - processed_messages.append(f"【{nickname}】{message_content}") - - # 组合所有消息 - combined_message = "\n".join(processed_messages) - combined_message = f"合并转发消息内容:\n{combined_message}" - - # 构建用户信息(使用转发消息的发送者) - user_info = UserInfo( - user_id=event.user_id, - user_nickname=event.sender.nickname, - user_cardname=event.sender.card if hasattr(event.sender, "card") else None, - platform="qq", - ) - - # 构建群聊信息(如果是群聊) - group_info = None - if isinstance(event, GroupMessageEvent): - group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq") - - # 创建消息对象 - message_cq = MessageRecvCQ( - message_id=event.message_id, - user_info=user_info, - raw_message=combined_message, - group_info=group_info, - reply_message=event.reply, - platform="qq", - ) - - # 进入标准消息处理流程 - await self.message_process(message_cq) - - async def process_message_segments(self, segments: list, layer: int) -> str: - """递归处理消息段""" - parts = [] - for seg in segments: - part = await self.process_segment(seg, layer + 1) - parts.append(part) - return "".join(parts) - - async def process_segment(self, seg: dict, layer: int) -> str: - """处理单个消息段""" - seg_type = seg["type"] - if layer > 3: - # 防止有那种100层转发消息炸飞麦麦 - return "【转发消息】" - if seg_type == "text": - return seg["data"]["text"] - elif seg_type == "image": - return "[图片]" - elif seg_type == "face": - return "[表情]" - elif seg_type == "at": - return f"@{seg['data'].get('qq', '未知用户')}" - elif seg_type == "forward": - # 递归处理嵌套的合并转发消息 - nested_nodes = seg["data"].get("content", []) - nested_messages = [] - nested_messages.append("合并转发消息内容:") - for node in nested_nodes: - nickname = node["sender"].get("nickname", "未知用户") - content = await self.process_message_segments(node["message"], layer=layer) - # nested_messages.append('-' * layer) - nested_messages.append(f"{'--' * layer}【{nickname}】{content}") - # nested_messages.append(f"{'--' * layer}合并转发第【{layer}】层结束") - return "\n".join(nested_messages) - else: - return f"[{seg_type}]" - # 创建全局ChatBot实例 chat_bot = ChatBot() diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index d5ab7b8a8..660afa290 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from ...common.database import db -from .message_base import GroupInfo, UserInfo +from ..message.message_base import GroupInfo, UserInfo from src.common.logger import get_module_logger diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py deleted file mode 100644 index 46b4c891f..000000000 --- a/src/plugins/chat/cq_code.py +++ /dev/null @@ -1,385 +0,0 @@ -import base64 -import html -import asyncio -from dataclasses import dataclass -from typing import Dict, List, Optional, Union -import ssl -import os -import aiohttp -from src.common.logger import get_module_logger -from nonebot import get_driver - -from ..models.utils_model import LLM_request -from .config import global_config -from .mapper import emojimapper -from .message_base import Seg -from .utils_user import get_user_nickname, get_groupname -from .message_base import GroupInfo, UserInfo - -driver = get_driver() -config = driver.config - -# 创建SSL上下文 -ssl_context = ssl.create_default_context() -ssl_context.set_ciphers("AES128-GCM-SHA256") - -logger = get_module_logger("cq_code") - - -@dataclass -class CQCode: - """ - CQ码数据类,用于存储和处理CQ码 - - 属性: - type: CQ码类型(如'image', 'at', 'face'等) - params: CQ码的参数字典 - raw_code: 原始CQ码字符串 - translated_segments: 经过处理后的Seg对象列表 - """ - - type: str - params: Dict[str, str] - group_info: Optional[GroupInfo] = None - user_info: Optional[UserInfo] = None - translated_segments: Optional[Union[Seg, List[Seg]]] = None - reply_message: Dict = None # 存储回复消息 - image_base64: Optional[str] = None - _llm: Optional[LLM_request] = None - - def __post_init__(self): - """初始化LLM实例""" - pass - - async def translate(self): - """根据CQ码类型进行相应的翻译处理,转换为Seg对象""" - if self.type == "text": - self.translated_segments = Seg(type="text", data=self.params.get("text", "")) - elif self.type == "image": - base64_data = await self.translate_image() - if base64_data: - if self.params.get("sub_type") == "0": - self.translated_segments = Seg(type="image", data=base64_data) - else: - self.translated_segments = Seg(type="emoji", data=base64_data) - else: - self.translated_segments = Seg(type="text", data="[图片]") - elif self.type == "at": - if self.params.get("qq") == "all": - self.translated_segments = Seg(type="text", data="@[全体成员]") - else: - user_nickname = get_user_nickname(self.params.get("qq", "")) - self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]") - elif self.type == "reply": - reply_segments = await self.translate_reply() - if reply_segments: - self.translated_segments = Seg(type="seglist", data=reply_segments) - else: - self.translated_segments = Seg(type="text", data="[回复某人消息]") - elif self.type == "face": - face_id = self.params.get("id", "") - self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]") - elif self.type == "forward": - forward_segments = await self.translate_forward() - if forward_segments: - self.translated_segments = Seg(type="seglist", data=forward_segments) - else: - self.translated_segments = Seg(type="text", data="[转发消息]") - else: - self.translated_segments = Seg(type="text", data=f"[{self.type}]") - - async def get_img(self) -> Optional[str]: - """异步获取图片并转换为base64""" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/50.0.2661.87 Safari/537.36", - "Accept": "text/html, application/xhtml xml, */*", - "Accept-Encoding": "gbk, GB2312", - "Accept-Language": "zh-cn", - "Content-Type": "application/x-www-form-urlencoded", - "Cache-Control": "no-cache", - } - - url = html.unescape(self.params["url"]) - if not url.startswith(("http://", "https://")): - return None - - max_retries = 3 - for retry in range(max_retries): - try: - logger.debug(f"获取图片中: {url}") - # 设置SSL上下文和创建连接器 - conn = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession(connector=conn) as session: - async with session.get( - url, - headers=headers, - timeout=aiohttp.ClientTimeout(total=15), - allow_redirects=True, - ) as response: - # 腾讯服务器特殊状态码处理 - if response.status == 400 and "multimedia.nt.qq.com.cn" in url: - return None - - if response.status != 200: - raise aiohttp.ClientError(f"HTTP {response.status}") - - # 验证内容类型 - content_type = response.headers.get("Content-Type", "") - if not content_type.startswith("image/"): - raise ValueError(f"非图片内容类型: {content_type}") - - # 读取响应内容 - content = await response.read() - logger.debug(f"获取图片成功: {url}") - - # 转换为Base64 - image_base64 = base64.b64encode(content).decode("utf-8") - self.image_base64 = image_base64 - return image_base64 - - except (aiohttp.ClientError, ValueError) as e: - if retry == max_retries - 1: - logger.error(f"最终请求失败: {str(e)}") - await asyncio.sleep(1.5**retry) # 指数退避 - - except Exception as e: - logger.exception(f"获取图片时发生未知错误: {str(e)}") - return None - - return None - - async def translate_image(self) -> Optional[str]: - """处理图片类型的CQ码,返回base64字符串""" - if "url" not in self.params: - return None - return await self.get_img() - - async def translate_forward(self) -> Optional[List[Seg]]: - """处理转发消息,返回Seg列表""" - try: - if "content" not in self.params: - return None - - content = self.unescape(self.params["content"]) - import ast - - try: - messages = ast.literal_eval(content) - except ValueError as e: - logger.error(f"解析转发消息内容失败: {str(e)}") - return None - - formatted_segments = [] - for msg in messages: - sender = msg.get("sender", {}) - nickname = sender.get("card") or sender.get("nickname", "未知用户") - raw_message = msg.get("raw_message", "") - message_array = msg.get("message", []) - - if message_array and isinstance(message_array, list): - for message_part in message_array: - if message_part.get("type") == "forward": - content_seg = Seg(type="text", data="[转发消息]") - break - else: - if raw_message: - from .message_cq import MessageRecvCQ - - user_info = UserInfo( - platform="qq", - user_id=msg.get("user_id", 0), - user_nickname=nickname, - ) - group_info = GroupInfo( - platform="qq", - group_id=msg.get("group_id", 0), - group_name=get_groupname(msg.get("group_id", 0)), - ) - - message_obj = MessageRecvCQ( - message_id=msg.get("message_id", 0), - user_info=user_info, - raw_message=raw_message, - plain_text=raw_message, - group_info=group_info, - ) - await message_obj.initialize() - content_seg = Seg(type="seglist", data=[message_obj.message_segment]) - else: - content_seg = Seg(type="text", data="[空消息]") - else: - if raw_message: - from .message_cq import MessageRecvCQ - - user_info = UserInfo( - platform="qq", - user_id=msg.get("user_id", 0), - user_nickname=nickname, - ) - group_info = GroupInfo( - platform="qq", - group_id=msg.get("group_id", 0), - group_name=get_groupname(msg.get("group_id", 0)), - ) - message_obj = MessageRecvCQ( - message_id=msg.get("message_id", 0), - user_info=user_info, - raw_message=raw_message, - plain_text=raw_message, - group_info=group_info, - ) - await message_obj.initialize() - content_seg = Seg(type="seglist", data=[message_obj.message_segment]) - else: - content_seg = Seg(type="text", data="[空消息]") - - formatted_segments.append(Seg(type="text", data=f"{nickname}: ")) - formatted_segments.append(content_seg) - formatted_segments.append(Seg(type="text", data="\n")) - - return formatted_segments - - except Exception as e: - logger.error(f"处理转发消息失败: {str(e)}") - return None - - async def translate_reply(self) -> Optional[List[Seg]]: - """处理回复类型的CQ码,返回Seg列表""" - from .message_cq import MessageRecvCQ - - if self.reply_message is None: - return None - if hasattr(self.reply_message, "group_id"): - group_info = GroupInfo(platform="qq", group_id=self.reply_message.group_id, group_name="") - else: - group_info = None - - if self.reply_message.sender.user_id: - message_obj = MessageRecvCQ( - user_info=UserInfo( - user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname - ), - message_id=self.reply_message.message_id, - raw_message=str(self.reply_message.message), - group_info=group_info, - ) - await message_obj.initialize() - - segments = [] - if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: - segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: ")) - else: - segments.append( - Seg( - type="text", - data=f"[回复 {self.reply_message.sender.nickname} 的消息: ", - ) - ) - - segments.append(Seg(type="seglist", data=[message_obj.message_segment])) - segments.append(Seg(type="text", data="]")) - return segments - else: - return None - - @staticmethod - def unescape(text: str) -> str: - """反转义CQ码中的特殊字符""" - return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&") - - -class CQCode_tool: - @staticmethod - def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode: - """ - 将CQ码字典转换为CQCode对象 - - Args: - cq_code: CQ码字典 - msg: MessageCQ对象 - reply: 回复消息的字典(可选) - - Returns: - CQCode对象 - """ - # 处理字典形式的CQ码 - # 从cq_code字典中获取type字段的值,如果不存在则默认为'text' - cq_type = cq_code.get("type", "text") - params = {} - if cq_type == "text": - params["text"] = cq_code.get("data", {}).get("text", "") - else: - params = cq_code.get("data", {}) - - instance = CQCode( - type=cq_type, - params=params, - group_info=msg.message_info.group_info, - user_info=msg.message_info.user_info, - reply_message=reply, - ) - - return instance - - @staticmethod - def create_reply_cq(message_id: int) -> str: - """ - 创建回复CQ码 - Args: - message_id: 回复的消息ID - Returns: - 回复CQ码字符串 - """ - return f"[CQ:reply,id={message_id}]" - - @staticmethod - def create_emoji_cq(file_path: str) -> str: - """ - 创建表情包CQ码 - Args: - file_path: 本地表情包文件路径 - Returns: - 表情包CQ码字符串 - """ - # 确保使用绝对路径 - abs_path = os.path.abspath(file_path) - # 转义特殊字符 - escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") - # 生成CQ码,设置sub_type=1表示这是表情包 - return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" - - @staticmethod - def create_emoji_cq_base64(base64_data: str) -> str: - """ - 创建表情包CQ码 - Args: - base64_data: base64编码的表情包数据 - Returns: - 表情包CQ码字符串 - """ - # 转义base64数据 - escaped_base64 = ( - base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") - ) - # 生成CQ码,设置sub_type=1表示这是表情包 - return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" - - @staticmethod - def create_image_cq_base64(base64_data: str) -> str: - """ - 创建表情包CQ码 - Args: - base64_data: base64编码的表情包数据 - Returns: - 表情包CQ码字符串 - """ - # 转义base64数据 - escaped_base64 = ( - base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",") - ) - # 生成CQ码,设置sub_type=1表示这是表情包 - return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]" - - -cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 683a37736..cc513734a 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -9,8 +9,6 @@ from typing import Optional, Tuple from PIL import Image import io -from nonebot import get_driver - from ...common.database import db from ..chat.config import global_config from ..chat.utils import get_embedding @@ -21,8 +19,6 @@ from src.common.logger import get_module_logger logger = get_module_logger("emoji") -driver = get_driver() -config = driver.config image_manager = ImageManager() @@ -118,9 +114,11 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = [e for e in - db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1}) - if 'blacklist' not in e] + all_emojis = [ + e + for e in db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1}) + if "blacklist" not in e + ] if not all_emojis: logger.warning("数据库中没有任何表情包") diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 556f36e2e..088c6fe4d 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -2,7 +2,6 @@ import random import time from typing import List, Optional, Tuple, Union -from nonebot import get_driver from ...common.database import db from ..models.utils_model import LLM_request @@ -21,9 +20,6 @@ llm_config = LogConfig( logger = get_module_logger("llm_generator", config=llm_config) -driver = get_driver() -config = driver.config - class ResponseGenerator: def __init__(self): diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index c340a7af9..b51bcfbec 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -9,7 +9,7 @@ import urllib3 from .utils_image import image_manager -from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase +from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase from .chat_stream import ChatStream from src.common.logger import get_module_logger @@ -75,19 +75,6 @@ class MessageRecv(Message): """ self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - message_segment = message_dict.get("message_segment", {}) - - if message_segment.get("data", "") == "[json]": - # 提取json消息中的展示信息 - pattern = r"\[CQ:json,data=(?P.+?)\]" - match = re.search(pattern, message_dict.get("raw_message", "")) - raw_json = html.unescape(match.group("json_data")) - try: - json_message = json.loads(raw_json) - except json.JSONDecodeError: - json_message = {} - message_segment["data"] = json_message.get("prompt", "") - self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.raw_message = message_dict.get("raw_message") diff --git a/src/plugins/chat/message_cq.py b/src/plugins/chat/message_cq.py deleted file mode 100644 index e80f07e93..000000000 --- a/src/plugins/chat/message_cq.py +++ /dev/null @@ -1,170 +0,0 @@ -import time -from dataclasses import dataclass -from typing import Dict, Optional - -import urllib3 - -from .cq_code import cq_code_tool -from .utils_cq import parse_cq_code -from .utils_user import get_groupname -from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase - -# 禁用SSL警告 -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - -# 这个类是消息数据类,用于存储和管理消息数据。 -# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 -# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 - - -@dataclass -class MessageCQ(MessageBase): - """QQ消息基类,继承自MessageBase - - 最小必要参数: - - message_id: 消息ID - - user_id: 发送者/接收者ID - - platform: 平台标识(默认为"qq") - """ - - def __init__( - self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq" - ): - # 构造基础消息信息 - message_info = BaseMessageInfo( - platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info - ) - # 调用父类初始化,message_segment 由子类设置 - super().__init__(message_info=message_info, message_segment=None, raw_message=None) - - -@dataclass -class MessageRecvCQ(MessageCQ): - """QQ接收消息类,用于解析raw_message到Seg对象""" - - def __init__( - self, - message_id: int, - user_info: UserInfo, - raw_message: str, - group_info: Optional[GroupInfo] = None, - platform: str = "qq", - reply_message: Optional[Dict] = None, - ): - # 调用父类初始化 - super().__init__(message_id, user_info, group_info, platform) - - # 私聊消息不携带group_info - if group_info is None: - pass - elif group_info.group_name is None: - group_info.group_name = get_groupname(group_info.group_id) - - # 解析消息段 - self.message_segment = None # 初始化为None - self.raw_message = raw_message - # 异步初始化在外部完成 - - # 添加对reply的解析 - self.reply_message = reply_message - - async def initialize(self): - """异步初始化方法""" - self.message_segment = await self._parse_message(self.raw_message, self.reply_message) - - async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg: - """异步解析消息内容为Seg对象""" - cq_code_dict_list = [] - segments = [] - - start = 0 - while True: - cq_start = message.find("[CQ:", start) - if cq_start == -1: - if start < len(message): - text = message[start:].strip() - if text: - cq_code_dict_list.append(parse_cq_code(text)) - break - - if cq_start > start: - text = message[start:cq_start].strip() - if text: - cq_code_dict_list.append(parse_cq_code(text)) - - cq_end = message.find("]", cq_start) - if cq_end == -1: - text = message[cq_start:].strip() - if text: - cq_code_dict_list.append(parse_cq_code(text)) - break - - cq_code = message[cq_start : cq_end + 1] - cq_code_dict_list.append(parse_cq_code(cq_code)) - start = cq_end + 1 - - # 转换CQ码为Seg对象 - for code_item in cq_code_dict_list: - cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message) - await cq_code_obj.translate() # 异步调用translate - if cq_code_obj.translated_segments: - segments.append(cq_code_obj.translated_segments) - - # 如果只有一个segment,直接返回 - if len(segments) == 1: - return segments[0] - - # 否则返回seglist类型的Seg - return Seg(type="seglist", data=segments) - - def to_dict(self) -> Dict: - """转换为字典格式,包含所有必要信息""" - base_dict = super().to_dict() - return base_dict - - -@dataclass -class MessageSendCQ(MessageCQ): - """QQ发送消息类,用于将Seg对象转换为raw_message""" - - def __init__(self, data: Dict): - # 调用父类初始化 - message_info = BaseMessageInfo.from_dict(data.get("message_info", {})) - message_segment = Seg.from_dict(data.get("message_segment", {})) - super().__init__( - message_info.message_id, - message_info.user_info, - message_info.group_info if message_info.group_info else None, - message_info.platform, - ) - - self.message_segment = message_segment - self.raw_message = self._generate_raw_message() - - def _generate_raw_message(self) -> str: - """将Seg对象转换为raw_message""" - segments = [] - - # 处理消息段 - if self.message_segment.type == "seglist": - for seg in self.message_segment.data: - segments.append(self._seg_to_cq_code(seg)) - else: - segments.append(self._seg_to_cq_code(self.message_segment)) - - return "".join(segments) - - def _seg_to_cq_code(self, seg: Seg) -> str: - """将单个Seg对象转换为CQ码字符串""" - if seg.type == "text": - return str(seg.data) - elif seg.type == "image": - return cq_code_tool.create_image_cq_base64(seg.data) - elif seg.type == "emoji": - return cq_code_tool.create_emoji_cq_base64(seg.data) - elif seg.type == "at": - return f"[CQ:at,qq={seg.data}]" - elif seg.type == "reply": - return cq_code_tool.create_reply_cq(int(seg.data)) - else: - return f"[{seg.data}]" diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index d79e9e7ab..50753219e 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -3,9 +3,8 @@ import time from typing import Dict, List, Optional, Union from src.common.logger import get_module_logger -from nonebot.adapters.onebot.v11 import Bot from ...common.database import db -from .message_cq import MessageSendCQ +from ..message.api import global_api from .message import MessageSending, MessageThinking, MessageSet from .storage import MessageStorage @@ -32,9 +31,9 @@ class Message_Sender: self.last_send_time = 0 self._current_bot = None - def set_bot(self, bot: Bot): + def set_bot(self, bot): """设置当前bot实例""" - self._current_bot = bot + pass def get_recalled_messages(self, stream_id: str) -> list: """获取所有撤回的消息""" @@ -60,31 +59,14 @@ class Message_Sender: break if not is_recalled: message_json = message.to_dict() - message_send = MessageSendCQ(data=message_json) + message_preview = truncate_message(message.processed_plain_text) - if message_send.message_info.group_info and message_send.message_info.group_info.group_id: - try: - await self._current_bot.send_group_msg( - group_id=message.message_info.group_info.group_id, - message=message_send.raw_message, - auto_escape=False, - ) + try: + result = await global_api.send_message("http://127.0.0.1:18002/api/message", message_json) + if result["status"] == "success": logger.success(f"发送消息“{message_preview}”成功") - except Exception as e: - logger.error(f"[调试] 发生错误 {e}") - logger.error(f"[调试] 发送消息“{message_preview}”失败") - else: - try: - logger.debug(message.message_info.user_info) - await self._current_bot.send_private_msg( - user_id=message.sender_info.user_id, - message=message_send.raw_message, - auto_escape=False, - ) - logger.success(f"发送消息“{message_preview}”成功") - except Exception as e: - logger.error(f"[调试] 发生错误 {e}") - logger.error(f"[调试] 发送消息“{message_preview}”失败") + except Exception as e: + logger.error(f"发送消息“{message_preview}”失败: {str(e)}") class MessageContainer: diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index 53cb0abbf..11113804a 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -3,7 +3,7 @@ from typing import Optional from src.common.logger import get_module_logger from ...common.database import db -from .message_base import UserInfo +from ..message.message_base import UserInfo from .chat_stream import ChatStream import math from bson.decimal128 import Decimal128 @@ -122,11 +122,15 @@ class RelationshipManager: relationship.relationship_value = float(relationship.relationship_value.to_decimal()) else: relationship.relationship_value = float(relationship.relationship_value) - logger.info(f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}") + logger.info( + f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}" + ) except (ValueError, TypeError): # 如果不能解析/强转则将relationship.relationship_value设置为double类型的0 relationship.relationship_value = 0.0 - logger.warning(f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型,已设置为0") + logger.warning( + f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型,已设置为0" + ) relationship.relationship_value += value await self.storage_relationship(relationship) relationship.saved = True diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index 6e11bc9d7..b15c855a2 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -1,6 +1,5 @@ from typing import List, Optional -from nonebot import get_driver from ..models.utils_model import LLM_request from .config import global_config @@ -15,9 +14,6 @@ topic_config = LogConfig( logger = get_module_logger("topic_identifier", config=topic_config) -driver = get_driver() -config = driver.config - class TopicIdentifier: def __init__(self): diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 0d63e7afc..545a84108 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -7,20 +7,17 @@ from typing import Dict, List import jieba import numpy as np -from nonebot import get_driver from src.common.logger import get_module_logger from ..models.utils_model import LLM_request from ..utils.typo_generator import ChineseTypoGenerator from .config import global_config from .message import MessageRecv, Message -from .message_base import UserInfo +from ..message.message_base import UserInfo from .chat_stream import ChatStream from ..moods.moods import MoodManager from ...common.database import db -driver = get_driver() -config = driver.config logger = get_module_logger("chat_utils") @@ -291,7 +288,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: for sentence in sentences: parts = sentence.split(",") current_sentence = parts[0] - if not is_western_paragraph(current_sentence): + if not is_western_paragraph(current_sentence): for part in parts[1:]: if random.random() < split_strength: new_sentences.append(current_sentence.strip()) @@ -323,7 +320,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: for sentence in sentences: sentence = sentence.rstrip(",,") # 西文字符句子不进行随机合并 - if not is_western_paragraph(current_sentence): + if not is_western_paragraph(current_sentence): if random.random() < split_strength * 0.5: sentence = sentence.replace(",", "").replace(",", "") elif random.random() < split_strength: @@ -364,10 +361,10 @@ def random_remove_punctuation(text: str) -> str: def process_llm_response(text: str) -> List[str]: # processed_response = process_text_with_typos(content) # 对西文字符段落的回复长度设置为汉字字符的两倍 - if len(text) > 100 and not is_western_paragraph(text) : + if len(text) > 100 and not is_western_paragraph(text): logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复") return ["懒得说"] - elif len(text) > 200 : + elif len(text) > 200: logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复") return ["懒得说"] # 处理长消息 @@ -530,12 +527,12 @@ def recover_kaomoji(sentences, placeholder_to_kaomoji): recovered_sentences.append(sentence) return recovered_sentences - + def is_western_char(char): """检测是否为西文字符""" - return len(char.encode('utf-8')) <= 2 + return len(char.encode("utf-8")) <= 2 + def is_western_paragraph(paragraph): """检测是否为西文字符段落""" return all(is_western_char(char) for char in paragraph if char.isalnum()) - \ No newline at end of file diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 7e20b35db..8bbd9e33c 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -6,7 +6,6 @@ from typing import Optional from PIL import Image import io -from nonebot import get_driver from ...common.database import db from ..chat.config import global_config @@ -16,9 +15,6 @@ from src.common.logger import get_module_logger logger = get_module_logger("chat_image") -driver = get_driver() -config = driver.config - class ImageManager: _instance = None diff --git a/src/plugins/config_reload/__init__.py b/src/plugins/config_reload/__init__.py index a802f8822..8b1378917 100644 --- a/src/plugins/config_reload/__init__.py +++ b/src/plugins/config_reload/__init__.py @@ -1,11 +1 @@ -from nonebot import get_app -from .api import router -from src.common.logger import get_module_logger -# 获取主应用实例并挂载路由 -app = get_app() -app.include_router(router, prefix="/api") - -# 打印日志,方便确认API已注册 -logger = get_module_logger("cfg_reload") -logger.success("配置重载API已注册,可通过 /api/reload-config 访问") diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 5aeb3d85a..a5464c52d 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -8,7 +8,6 @@ import re import jieba import networkx as nx -from nonebot import get_driver from ...common.database import db from ..chat.config import global_config from ..chat.utils import ( @@ -232,13 +231,13 @@ class Hippocampus: # 创建双峰分布的记忆调度器 scheduler = MemoryBuildScheduler( - n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值(4小时前) - std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差 - weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60% - n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值(24小时前) - std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差 - weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40% - total_samples=global_config.build_memory_sample_num # 总共生成10个时间点 + n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值(4小时前) + std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差 + weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60% + n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值(24小时前) + std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差 + weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40% + total_samples=global_config.build_memory_sample_num, # 总共生成10个时间点 ) # 生成时间戳数组 @@ -250,9 +249,7 @@ class Hippocampus: chat_samples = [] for timestamp in timestamps: messages = self.random_get_msg_snippet( - timestamp, - global_config.build_memory_sample_length, - max_memorized_time_per_msg + timestamp, global_config.build_memory_sample_length, max_memorized_time_per_msg ) if messages: time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 @@ -297,16 +294,16 @@ class Hippocampus: topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) # 使用正则表达式提取<>中的内容 - topics = re.findall(r'<([^>]+)>', topics_response[0]) - + topics = re.findall(r"<([^>]+)>", topics_response[0]) + # 如果没有找到<>包裹的内容,返回['none'] if not topics: - topics = ['none'] + topics = ["none"] else: # 处理提取出的话题 topics = [ topic.strip() - for topic in ','.join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip() ] @@ -314,8 +311,7 @@ class Hippocampus: # any()检查topic中是否包含任何一个filter_keywords中的关键词 # 只保留不包含禁用关键词的topic filtered_topics = [ - topic for topic in topics - if not any(keyword in topic for keyword in global_config.memory_ban_words) + topic for topic in topics if not any(keyword in topic for keyword in global_config.memory_ban_words) ] logger.debug(f"过滤后话题: {filtered_topics}") @@ -331,14 +327,14 @@ class Hippocampus: # 初始化压缩后的记忆集合和相似主题字典 compressed_memory = set() # 存储压缩后的(主题,内容)元组 similar_topics_dict = {} # 存储每个话题的相似主题列表 - + # 遍历每个主题及其对应的LLM任务 for topic, task in tasks: response = await task if response: # 将主题和LLM生成的内容添加到压缩记忆中 compressed_memory.add((topic, response[0])) - + # 为当前主题寻找相似的已存在主题 existing_topics = list(self.memory_graph.G.nodes()) similar_topics = [] @@ -404,7 +400,7 @@ class Hippocampus: logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") all_added_nodes.extend(topic for topic, _ in compressed_memory) # all_connected_nodes.extend(topic for topic, _ in similar_topics_dict) - + for topic, memory in compressed_memory: self.memory_graph.add_dot(topic, memory) all_topics.append(topic) @@ -415,13 +411,13 @@ class Hippocampus: for similar_topic, similarity in similar_topics: if topic != similar_topic: strength = int(similarity * 10) - + logger.debug(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})") all_added_edges.append(f"{topic}-{similar_topic}") - + all_connected_nodes.append(topic) all_connected_nodes.append(similar_topic) - + self.memory_graph.G.add_edge( topic, similar_topic, @@ -442,11 +438,10 @@ class Hippocampus: logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") # logger.success(f"强化连接: {', '.join(all_added_edges)}") self.sync_memory_to_db() - + end_time = time.time() logger.success( - f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} " - "秒--------------------------" + f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒--------------------------" ) def sync_memory_to_db(self): @@ -800,16 +795,16 @@ class Hippocampus: topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 4)) # 使用正则表达式提取<>中的内容 print(f"话题: {topics_response[0]}") - topics = re.findall(r'<([^>]+)>', topics_response[0]) - + topics = re.findall(r"<([^>]+)>", topics_response[0]) + # 如果没有找到<>包裹的内容,返回['none'] if not topics: - topics = ['none'] + topics = ["none"] else: # 处理提取出的话题 topics = [ topic.strip() - for topic in ','.join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") + for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip() ] @@ -885,7 +880,7 @@ class Hippocampus: # 识别主题 identified_topics = await self._identify_topics(text) print(f"识别主题: {identified_topics}") - + if identified_topics[0] == "none": return 0 @@ -946,7 +941,7 @@ class Hippocampus: # 计算最终激活值 activation = int((topic_match + average_similarities) / 2 * 100) - + logger.info(f"识别<{text[:15]}...>主题: {identified_topics}, 匹配率: {topic_match:.3f}, 激活值: {activation}") return activation @@ -994,9 +989,6 @@ def segment_text(text): return seg_text -driver = get_driver() -config = driver.config - start_time = time.time() # 创建记忆图 diff --git a/src/plugins/message/__init__.py b/src/plugins/message/__init__.py new file mode 100644 index 000000000..bee5c5e58 --- /dev/null +++ b/src/plugins/message/__init__.py @@ -0,0 +1,26 @@ +"""Maim Message - A message handling library""" + +__version__ = "0.1.0" + +from .api import BaseMessageAPI, global_api +from .message_base import ( + Seg, + GroupInfo, + UserInfo, + FormatInfo, + TemplateInfo, + BaseMessageInfo, + MessageBase, +) + +__all__ = [ + "BaseMessageAPI", + "Seg", + "global_api", + "GroupInfo", + "UserInfo", + "FormatInfo", + "TemplateInfo", + "BaseMessageInfo", + "MessageBase", +] diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py new file mode 100644 index 000000000..355817efb --- /dev/null +++ b/src/plugins/message/api.py @@ -0,0 +1,86 @@ +from fastapi import FastAPI, HTTPException +from typing import Optional, Dict, Any, Callable, List +import aiohttp +import asyncio +import uvicorn +import os + + +class BaseMessageAPI: + def __init__(self, host: str = "0.0.0.0", port: int = 18000): + self.app = FastAPI() + self.host = host + self.port = port + self.message_handlers: List[Callable] = [] + self._setup_routes() + self._running = False + + def _setup_routes(self): + """设置基础路由""" + + @self.app.post("/api/message") + async def handle_message(message: Dict[str, Any]): + # try: + for handler in self.message_handlers: + await handler(message) + return {"status": "success"} + # except Exception as e: + # raise HTTPException(status_code=500, detail=str(e)) from e + + def register_message_handler(self, handler: Callable): + """注册消息处理函数""" + self.message_handlers.append(handler) + + async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: + """发送消息到指定端点""" + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: + return await response.json() + except Exception as e: + # logger.error(f"发送消息失败: {str(e)}") + pass + + def run_sync(self): + """同步方式运行服务器""" + uvicorn.run(self.app, host=self.host, port=self.port) + + async def run(self): + """异步方式运行服务器""" + config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") + self.server = uvicorn.Server(config) + + await self.server.serve() + + async def start_server(self): + """启动服务器的异步方法""" + if not self._running: + self._running = True + await self.run() + + async def stop(self): + """停止服务器""" + if hasattr(self, "server"): + self._running = False + # 正确关闭 uvicorn 服务器 + self.server.should_exit = True + await self.server.shutdown() + # 等待服务器完全停止 + if hasattr(self.server, "started") and self.server.started: + await self.server.main_loop() + # 清理处理程序 + self.message_handlers.clear() + + def start(self): + """启动服务器的便捷方法""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.start_server()) + except KeyboardInterrupt: + pass + finally: + loop.close() + + +global_api = BaseMessageAPI(host=os.environ["HOST"], port=os.environ["PORT"]) diff --git a/src/plugins/chat/message_base.py b/src/plugins/message/message_base.py similarity index 73% rename from src/plugins/chat/message_base.py rename to src/plugins/message/message_base.py index 8ad1a9922..461fe0167 100644 --- a/src/plugins/chat/message_base.py +++ b/src/plugins/message/message_base.py @@ -103,6 +103,63 @@ class UserInfo: ) +@dataclass +class FormatInfo: + """格式信息类""" + + """ + 目前maimcore可接受的格式为text,image,emoji + 可发送的格式为text,emoji,reply + """ + + content_format: Optional[str] = None + accept_format: Optional[str] = None + + def to_dict(self) -> Dict: + """转换为字典格式""" + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict) -> "FormatInfo": + """从字典创建FormatInfo实例 + Args: + data: 包含必要字段的字典 + Returns: + FormatInfo: 新的实例 + """ + return cls( + content_format=data.get("content_format"), + accept_format=data.get("accept_format"), + ) + + +@dataclass +class TemplateInfo: + """模板信息类""" + + template_items: Optional[List[Dict]] = None + template_name: Optional[str] = None + template_default: bool = True + + def to_dict(self) -> Dict: + """转换为字典格式""" + return {k: v for k, v in asdict(self).items() if v is not None} + + @classmethod + def from_dict(cls, data: Dict) -> "TemplateInfo": + """从字典创建TemplateInfo实例 + Args: + data: 包含必要字段的字典 + Returns: + TemplateInfo: 新的实例 + """ + return cls( + template_items=data.get("template_items"), + template_name=data.get("template_name"), + template_default=data.get("template_default", True), + ) + + @dataclass class BaseMessageInfo: """消息信息类""" @@ -112,13 +169,15 @@ class BaseMessageInfo: time: Optional[int] = None group_info: Optional[GroupInfo] = None user_info: Optional[UserInfo] = None + format_info: Optional[FormatInfo] = None + template_info: Optional[TemplateInfo] = None def to_dict(self) -> Dict: """转换为字典格式""" result = {} for field, value in asdict(self).items(): if value is not None: - if isinstance(value, (GroupInfo, UserInfo)): + if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)): result[field] = value.to_dict() else: result[field] = value @@ -136,12 +195,16 @@ class BaseMessageInfo: """ group_info = GroupInfo.from_dict(data.get("group_info", {})) user_info = UserInfo.from_dict(data.get("user_info", {})) + format_info = FormatInfo.from_dict(data.get("format_info", {})) + template_info = TemplateInfo.from_dict(data.get("template_info", {})) return cls( platform=data.get("platform"), message_id=data.get("message_id"), time=data.get("time"), group_info=group_info, user_info=user_info, + format_info=format_info, + template_info=template_info, ) diff --git a/src/plugins/message/test.py b/src/plugins/message/test.py new file mode 100644 index 000000000..bc4ba4d8c --- /dev/null +++ b/src/plugins/message/test.py @@ -0,0 +1,98 @@ +import unittest +import asyncio +import aiohttp +from api import BaseMessageAPI +from message_base import ( + BaseMessageInfo, + UserInfo, + GroupInfo, + FormatInfo, + TemplateInfo, + MessageBase, + Seg, +) + + +send_url = "http://localhost" +receive_port = 18002 # 接收消息的端口 +send_port = 18000 # 发送消息的端口 +test_endpoint = "/api/message" + +# 创建并启动API实例 +api = BaseMessageAPI(host="0.0.0.0", port=receive_port) + + +class TestLiveAPI(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + """测试前的设置""" + self.received_messages = [] + + async def message_handler(message): + self.received_messages.append(message) + + self.api = api + self.api.register_message_handler(message_handler) + self.server_task = asyncio.create_task(self.api.run()) + try: + await asyncio.wait_for(asyncio.sleep(1), timeout=5) + except asyncio.TimeoutError: + self.skipTest("服务器启动超时") + + async def asyncTearDown(self): + """测试后的清理""" + if hasattr(self, "server_task"): + await self.api.stop() # 先调用正常的停止流程 + if not self.server_task.done(): + self.server_task.cancel() + try: + await asyncio.wait_for(self.server_task, timeout=100) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + async def test_send_and_receive_message(self): + """测试向运行中的API发送消息并接收响应""" + # 准备测试消息 + user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq") + group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq") + format_info = FormatInfo( + content_format=["text"], accept_format=["text", "emoji", "reply"] + ) + template_info = None + message_info = BaseMessageInfo( + platform="qq", + message_id=12345678, + time=12345678, + group_info=group_info, + user_info=user_info, + format_info=format_info, + template_info=template_info, + ) + message = MessageBase( + message_info=message_info, + raw_message="测试消息", + message_segment=Seg(type="text", data="测试消息"), + ) + test_message = message.to_dict() + + # 发送测试消息到发送端口 + async with aiohttp.ClientSession() as session: + async with session.post( + f"{send_url}:{send_port}{test_endpoint}", + json=test_message, + ) as response: + response_data = await response.json() + self.assertEqual(response.status, 200) + self.assertEqual(response_data["status"], "success") + try: + async with asyncio.timeout(5): # 设置5秒超时 + while len(self.received_messages) == 0: + await asyncio.sleep(0.1) + received_message = self.received_messages[0] + print(received_message) + self.received_messages.clear() + except asyncio.TimeoutError: + self.fail("等待接收消息超时") + + +if __name__ == "__main__": + unittest.main() diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 5ad69ff25..578313f06 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -6,15 +6,13 @@ from typing import Tuple, Union import aiohttp from src.common.logger import get_module_logger -from nonebot import get_driver import base64 from PIL import Image import io +import os from ...common.database import db from ..chat.config import global_config -driver = get_driver() -config = driver.config logger = get_module_logger("model_utils") @@ -34,8 +32,9 @@ class LLM_request: def __init__(self, model, **kwargs): # 将大写的配置键转换为小写并从config中获取实际值 try: - self.api_key = getattr(config, model["key"]) - self.base_url = getattr(config, model["base_url"]) + self.api_key = os.environ[model["key"]] + self.base_url = os.environ[model["base_url"]] + print(self.api_key, self.base_url) except AttributeError as e: logger.error(f"原始 model dict 信息:{model}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index b26b29549..e14cc014a 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -3,7 +3,6 @@ import json import re from typing import Dict, Union -from nonebot import get_driver # 添加项目根目录到 Python 路径 @@ -14,9 +13,6 @@ from src.common.logger import get_module_logger logger = get_module_logger("scheduler") -driver = get_driver() -config = driver.config - class ScheduleGenerator: enable_output: bool = True @@ -183,5 +179,7 @@ class ScheduleGenerator: logger.info(f"时间[{time_str}]: 活动[{activity}]") logger.info("==================") self.enable_output = False + + # 当作为组件导入时使用的实例 bot_schedule = ScheduleGenerator() diff --git a/如果你更新了版本,点我.txt b/如果你更新了版本,点我.txt deleted file mode 100644 index 400e8ae0c..000000000 --- a/如果你更新了版本,点我.txt +++ /dev/null @@ -1,4 +0,0 @@ -更新版本后,建议删除数据库messages中所有内容,不然会出现报错 -该操作不会影响你的记忆 - -如果显示配置文件版本过低,运行根目录的bat \ No newline at end of file