refactor: 初步重构为maimcore
This commit is contained in:
61
bot.py
61
bot.py
@@ -4,15 +4,11 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import nonebot
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from nonebot.adapters.onebot.v11 import Adapter
|
|
||||||
import platform
|
import platform
|
||||||
|
from dotenv import load_dotenv
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
from src.main import MainSystem
|
||||||
|
|
||||||
logger = get_module_logger("main_bot")
|
logger = get_module_logger("main_bot")
|
||||||
|
|
||||||
@@ -134,11 +130,7 @@ def scan_provider(env_config: dict):
|
|||||||
|
|
||||||
async def graceful_shutdown():
|
async def graceful_shutdown():
|
||||||
try:
|
try:
|
||||||
global uvicorn_server
|
logger.info("正在优雅关闭麦麦...")
|
||||||
if uvicorn_server:
|
|
||||||
uvicorn_server.force_exit = True # 强制退出
|
|
||||||
await uvicorn_server.shutdown()
|
|
||||||
|
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
@@ -148,22 +140,6 @@ async def graceful_shutdown():
|
|||||||
logger.error(f"麦麦关闭失败: {e}")
|
logger.error(f"麦麦关闭失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def uvicorn_main():
|
|
||||||
global uvicorn_server
|
|
||||||
config = uvicorn.Config(
|
|
||||||
app="__main__:app",
|
|
||||||
host=os.getenv("HOST", "127.0.0.1"),
|
|
||||||
port=int(os.getenv("PORT", 8080)),
|
|
||||||
reload=os.getenv("ENVIRONMENT") == "dev",
|
|
||||||
timeout_graceful_shutdown=5,
|
|
||||||
log_config=None,
|
|
||||||
access_log=False,
|
|
||||||
)
|
|
||||||
server = uvicorn.Server(config)
|
|
||||||
uvicorn_server = server
|
|
||||||
await server.serve()
|
|
||||||
|
|
||||||
|
|
||||||
def check_eula():
|
def check_eula():
|
||||||
eula_confirm_file = Path("eula.confirmed")
|
eula_confirm_file = Path("eula.confirmed")
|
||||||
privacy_confirm_file = Path("privacy.confirmed")
|
privacy_confirm_file = Path("privacy.confirmed")
|
||||||
@@ -245,7 +221,6 @@ def check_eula():
|
|||||||
|
|
||||||
def raw_main():
|
def raw_main():
|
||||||
# 利用 TZ 环境变量设定程序工作的时区
|
# 利用 TZ 环境变量设定程序工作的时区
|
||||||
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
|
|
||||||
if platform.system().lower() != "windows":
|
if platform.system().lower() != "windows":
|
||||||
time.tzset()
|
time.tzset()
|
||||||
|
|
||||||
@@ -256,40 +231,26 @@ def raw_main():
|
|||||||
init_env()
|
init_env()
|
||||||
load_env()
|
load_env()
|
||||||
|
|
||||||
# load_logger()
|
|
||||||
|
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
scan_provider(env_config)
|
scan_provider(env_config)
|
||||||
|
|
||||||
# 设置基础配置
|
# 返回MainSystem实例
|
||||||
base_config = {
|
return MainSystem()
|
||||||
"websocket_port": int(env_config.get("PORT", 8080)),
|
|
||||||
"host": env_config.get("HOST", "127.0.0.1"),
|
|
||||||
"log_level": "INFO",
|
|
||||||
}
|
|
||||||
|
|
||||||
# 合并配置
|
|
||||||
nonebot.init(**base_config, **env_config)
|
|
||||||
|
|
||||||
# 注册适配器
|
|
||||||
global driver
|
|
||||||
driver = nonebot.get_driver()
|
|
||||||
driver.register_adapter(Adapter)
|
|
||||||
|
|
||||||
# 加载插件
|
|
||||||
nonebot.load_plugins("src/plugins")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
raw_main()
|
# 获取MainSystem实例
|
||||||
|
main_system = raw_main()
|
||||||
|
|
||||||
app = nonebot.get_asgi()
|
# 创建事件循环
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(uvicorn_main())
|
# 执行初始化和任务调度
|
||||||
|
loop.run_until_complete(main_system.initialize())
|
||||||
|
loop.run_until_complete(main_system.schedule_tasks())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
loop.run_until_complete(graceful_shutdown())
|
loop.run_until_complete(graceful_shutdown())
|
||||||
|
|||||||
60
src/main.py
60
src/main.py
@@ -1,21 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from .plugins.utils.statistic import LLMStatistics
|
||||||
from plugins.utils.statistic import LLMStatistics
|
from .plugins.moods.moods import MoodManager
|
||||||
from plugins.moods.moods import MoodManager
|
from .plugins.schedule.schedule_generator import bot_schedule
|
||||||
from plugins.schedule.schedule_generator import bot_schedule
|
from .plugins.chat.emoji_manager import emoji_manager
|
||||||
from plugins.chat.emoji_manager import emoji_manager
|
from .plugins.chat.relationship_manager import relationship_manager
|
||||||
from plugins.chat.relationship_manager import relationship_manager
|
from .plugins.willing.willing_manager import willing_manager
|
||||||
from plugins.willing.willing_manager import willing_manager
|
from .plugins.chat.chat_stream import chat_manager
|
||||||
from plugins.chat.chat_stream import chat_manager
|
from .plugins.memory_system.memory import hippocampus
|
||||||
from plugins.memory_system.memory import hippocampus
|
from .plugins.chat.message_sender import message_manager
|
||||||
from plugins.chat.message_sender import message_manager
|
from .plugins.chat.storage import MessageStorage
|
||||||
from plugins.chat.storage import MessageStorage
|
from .plugins.chat.config import global_config
|
||||||
from plugins.chat.config import global_config
|
from .plugins.chat.bot import chat_bot
|
||||||
from common.logger import get_module_logger
|
from .common.logger import get_module_logger
|
||||||
from fastapi import FastAPI
|
|
||||||
from plugins.chat.api import app as api_app
|
|
||||||
|
|
||||||
logger = get_module_logger("main")
|
logger = get_module_logger("main")
|
||||||
|
|
||||||
@@ -25,13 +23,29 @@ class MainSystem:
|
|||||||
self.llm_stats = LLMStatistics("llm_statistics.txt")
|
self.llm_stats = LLMStatistics("llm_statistics.txt")
|
||||||
self.mood_manager = MoodManager.get_instance()
|
self.mood_manager = MoodManager.get_instance()
|
||||||
self._message_manager_started = False
|
self._message_manager_started = False
|
||||||
self.app = FastAPI()
|
|
||||||
self.app.mount("/chat", api_app)
|
# 使用消息API替代直接的FastAPI实例
|
||||||
|
from .plugins.message import global_api
|
||||||
|
|
||||||
|
self.app = global_api
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
|
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
|
||||||
|
|
||||||
|
# 启动API服务器(改为异步启动)
|
||||||
|
self.api_task = asyncio.create_task(self.app.run())
|
||||||
|
|
||||||
|
# 其他初始化任务
|
||||||
|
await asyncio.gather(
|
||||||
|
self._init_components(), # 将原有的初始化代码移到这个新方法中
|
||||||
|
# api_task,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.success("系统初始化完成")
|
||||||
|
|
||||||
|
async def _init_components(self):
|
||||||
|
"""初始化其他组件"""
|
||||||
# 启动LLM统计
|
# 启动LLM统计
|
||||||
self.llm_stats.start()
|
self.llm_stats.start()
|
||||||
logger.success("LLM统计功能启动成功")
|
logger.success("LLM统计功能启动成功")
|
||||||
@@ -64,10 +78,7 @@ class MainSystem:
|
|||||||
bot_schedule.print_schedule()
|
bot_schedule.print_schedule()
|
||||||
|
|
||||||
# 启动FastAPI服务器
|
# 启动FastAPI服务器
|
||||||
import uvicorn
|
self.app.register_message_handler(chat_bot.message_process)
|
||||||
|
|
||||||
uvicorn.run(self.app, host="0.0.0.0", port=18000)
|
|
||||||
logger.success("API服务器启动成功")
|
|
||||||
|
|
||||||
async def schedule_tasks(self):
|
async def schedule_tasks(self):
|
||||||
"""调度定时任务"""
|
"""调度定时任务"""
|
||||||
@@ -86,6 +97,7 @@ class MainSystem:
|
|||||||
async def build_memory_task(self):
|
async def build_memory_task(self):
|
||||||
"""记忆构建任务"""
|
"""记忆构建任务"""
|
||||||
while True:
|
while True:
|
||||||
|
logger.info("正在进行记忆构建")
|
||||||
await hippocampus.operation_build_memory()
|
await hippocampus.operation_build_memory()
|
||||||
await asyncio.sleep(global_config.build_memory_interval)
|
await asyncio.sleep(global_config.build_memory_interval)
|
||||||
|
|
||||||
@@ -100,6 +112,7 @@ class MainSystem:
|
|||||||
async def merge_memory_task(self):
|
async def merge_memory_task(self):
|
||||||
"""记忆整合任务"""
|
"""记忆整合任务"""
|
||||||
while True:
|
while True:
|
||||||
|
logger.info("正在进行记忆整合")
|
||||||
await asyncio.sleep(global_config.build_memory_interval + 10)
|
await asyncio.sleep(global_config.build_memory_interval + 10)
|
||||||
|
|
||||||
async def print_mood_task(self):
|
async def print_mood_task(self):
|
||||||
@@ -130,8 +143,9 @@ class MainSystem:
|
|||||||
async def main():
|
async def main():
|
||||||
"""主函数"""
|
"""主函数"""
|
||||||
system = MainSystem()
|
system = MainSystem()
|
||||||
await system.initialize()
|
await asyncio.gather(system.initialize(), system.schedule_tasks(), system.api_task)
|
||||||
await system.schedule_tasks()
|
# await system.initialize()
|
||||||
|
# await system.schedule_tasks()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
23
src/plugins/__init__.py
Normal file
23
src/plugins/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""
|
||||||
|
MaiMBot插件系统
|
||||||
|
包含聊天、情绪、记忆、日程等功能模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .chat.chat_stream import chat_manager
|
||||||
|
from .chat.emoji_manager import emoji_manager
|
||||||
|
from .chat.relationship_manager import relationship_manager
|
||||||
|
from .moods.moods import MoodManager
|
||||||
|
from .willing.willing_manager import willing_manager
|
||||||
|
from .memory_system.memory import hippocampus
|
||||||
|
from .schedule.schedule_generator import bot_schedule
|
||||||
|
|
||||||
|
# 导出主要组件供外部使用
|
||||||
|
__all__ = [
|
||||||
|
"chat_manager",
|
||||||
|
"emoji_manager",
|
||||||
|
"relationship_manager",
|
||||||
|
"MoodManager",
|
||||||
|
"willing_manager",
|
||||||
|
"hippocampus",
|
||||||
|
"bot_schedule",
|
||||||
|
]
|
||||||
@@ -1,154 +1,15 @@
|
|||||||
import asyncio
|
|
||||||
import time
|
|
||||||
|
|
||||||
from nonebot import get_driver, on_message, on_notice, require
|
|
||||||
from nonebot.adapters.onebot.v11 import Bot, MessageEvent, NoticeEvent
|
|
||||||
from nonebot.typing import T_State
|
|
||||||
|
|
||||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
|
||||||
from ..schedule.schedule_generator import bot_schedule
|
|
||||||
from ..utils.statistic import LLMStatistics
|
|
||||||
from .bot import chat_bot
|
|
||||||
from .config import global_config
|
|
||||||
from .emoji_manager import emoji_manager
|
from .emoji_manager import emoji_manager
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from ..willing.willing_manager import willing_manager
|
|
||||||
from .chat_stream import chat_manager
|
from .chat_stream import chat_manager
|
||||||
from ..memory_system.memory import hippocampus
|
from .message_sender import message_manager
|
||||||
from .message_sender import message_manager, message_sender
|
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from src.common.logger import get_module_logger
|
from .config import global_config
|
||||||
|
|
||||||
logger = get_module_logger("chat_init")
|
__all__ = [
|
||||||
|
"emoji_manager",
|
||||||
# 创建LLM统计实例
|
"relationship_manager",
|
||||||
llm_stats = LLMStatistics("llm_statistics.txt")
|
"chat_manager",
|
||||||
|
"message_manager",
|
||||||
# 添加标志变量
|
"MessageStorage",
|
||||||
_message_manager_started = False
|
"global_config",
|
||||||
|
]
|
||||||
# 获取驱动器
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
# 初始化表情管理器
|
|
||||||
emoji_manager.initialize()
|
|
||||||
|
|
||||||
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
|
|
||||||
# 注册消息处理器
|
|
||||||
msg_in = on_message(priority=5)
|
|
||||||
# 注册和bot相关的通知处理器
|
|
||||||
notice_matcher = on_notice(priority=1)
|
|
||||||
# 创建定时任务
|
|
||||||
scheduler = require("nonebot_plugin_apscheduler").scheduler
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def start_background_tasks():
|
|
||||||
"""启动后台任务"""
|
|
||||||
# 启动LLM统计
|
|
||||||
llm_stats.start()
|
|
||||||
logger.success("LLM统计功能启动成功")
|
|
||||||
|
|
||||||
# 初始化并启动情绪管理器
|
|
||||||
mood_manager = MoodManager.get_instance()
|
|
||||||
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
|
|
||||||
logger.success("情绪管理器启动成功")
|
|
||||||
|
|
||||||
# 只启动表情包管理任务
|
|
||||||
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
|
|
||||||
await bot_schedule.initialize()
|
|
||||||
bot_schedule.print_schedule()
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_startup
|
|
||||||
async def init_relationships():
|
|
||||||
"""在 NoneBot2 启动时初始化关系管理器"""
|
|
||||||
logger.debug("正在加载用户关系数据...")
|
|
||||||
await relationship_manager.load_all_relationships()
|
|
||||||
asyncio.create_task(relationship_manager._start_relationship_manager())
|
|
||||||
|
|
||||||
|
|
||||||
@driver.on_bot_connect
|
|
||||||
async def _(bot: Bot):
|
|
||||||
"""Bot连接成功时的处理"""
|
|
||||||
global _message_manager_started
|
|
||||||
logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
|
|
||||||
await willing_manager.ensure_started()
|
|
||||||
|
|
||||||
message_sender.set_bot(bot)
|
|
||||||
logger.success("-----------消息发送器已启动!-----------")
|
|
||||||
|
|
||||||
if not _message_manager_started:
|
|
||||||
asyncio.create_task(message_manager.start_processor())
|
|
||||||
_message_manager_started = True
|
|
||||||
logger.success("-----------消息处理器已启动!-----------")
|
|
||||||
|
|
||||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
|
||||||
logger.success("-----------开始偷表情包!-----------")
|
|
||||||
asyncio.create_task(chat_manager._initialize())
|
|
||||||
asyncio.create_task(chat_manager._auto_save_task())
|
|
||||||
|
|
||||||
|
|
||||||
@msg_in.handle()
|
|
||||||
async def _(bot: Bot, event: MessageEvent, state: T_State):
|
|
||||||
# 处理合并转发消息
|
|
||||||
if "forward" in event.message:
|
|
||||||
await chat_bot.handle_forward_message(event, bot)
|
|
||||||
else:
|
|
||||||
await chat_bot.handle_message(event, bot)
|
|
||||||
|
|
||||||
|
|
||||||
@notice_matcher.handle()
|
|
||||||
async def _(bot: Bot, event: NoticeEvent, state: T_State):
|
|
||||||
logger.debug(f"收到通知:{event}")
|
|
||||||
await chat_bot.handle_notice(event, bot)
|
|
||||||
|
|
||||||
|
|
||||||
# 添加build_memory定时任务
|
|
||||||
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
|
|
||||||
async def build_memory_task():
|
|
||||||
"""每build_memory_interval秒执行一次记忆构建"""
|
|
||||||
await hippocampus.operation_build_memory()
|
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
|
|
||||||
async def forget_memory_task():
|
|
||||||
"""每30秒执行一次记忆构建"""
|
|
||||||
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
|
|
||||||
await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage)
|
|
||||||
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
|
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
|
|
||||||
async def merge_memory_task():
|
|
||||||
"""每30秒执行一次记忆构建"""
|
|
||||||
# print("\033[1;32m[记忆整合]\033[0m 开始整合")
|
|
||||||
# await hippocampus.operation_merge_memory(percentage=0.1)
|
|
||||||
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
|
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=30, id="print_mood")
|
|
||||||
async def print_mood_task():
|
|
||||||
"""每30秒打印一次情绪状态"""
|
|
||||||
mood_manager = MoodManager.get_instance()
|
|
||||||
mood_manager.print_mood_status()
|
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=7200, id="generate_schedule")
|
|
||||||
async def generate_schedule_task():
|
|
||||||
"""每2小时尝试生成一次日程"""
|
|
||||||
logger.debug("尝试生成日程")
|
|
||||||
await bot_schedule.initialize()
|
|
||||||
if not bot_schedule.enable_output:
|
|
||||||
bot_schedule.print_schedule()
|
|
||||||
|
|
||||||
|
|
||||||
@scheduler.scheduled_job("interval", seconds=3600, id="remove_recalled_message")
|
|
||||||
async def remove_recalled_message() -> None:
|
|
||||||
"""删除撤回消息"""
|
|
||||||
try:
|
|
||||||
storage = MessageStorage()
|
|
||||||
await storage.remove_recalled_message(time.time())
|
|
||||||
except Exception:
|
|
||||||
logger.exception("删除撤回消息失败")
|
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
from .bot import chat_bot
|
|
||||||
from .message_cq import MessageRecvCQ
|
|
||||||
from .message_base import UserInfo, GroupInfo
|
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
|
|
||||||
logger = get_module_logger("chat_api")
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
class MessageRequest(BaseModel):
|
|
||||||
message_id: int
|
|
||||||
user_info: Dict[str, Any]
|
|
||||||
raw_message: str
|
|
||||||
group_info: Optional[Dict[str, Any]] = None
|
|
||||||
reply_message: Optional[Dict[str, Any]] = None
|
|
||||||
platform: str = "api"
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/message")
|
|
||||||
async def handle_message(message: MessageRequest):
|
|
||||||
try:
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=message.user_info["user_id"],
|
|
||||||
user_nickname=message.user_info["user_nickname"],
|
|
||||||
user_cardname=message.user_info.get("user_cardname"),
|
|
||||||
platform=message.platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
group_info = None
|
|
||||||
if message.group_info:
|
|
||||||
group_info = GroupInfo(
|
|
||||||
group_id=message.group_info["group_id"],
|
|
||||||
group_name=message.group_info.get("group_name"),
|
|
||||||
platform=message.platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_cq = MessageRecvCQ(
|
|
||||||
message_id=message.message_id,
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=message.raw_message,
|
|
||||||
group_info=group_info,
|
|
||||||
reply_message=message.reply_message,
|
|
||||||
platform=message.platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
await chat_bot.message_process(message_cq)
|
|
||||||
return {"status": "success"}
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("API处理消息时出错")
|
|
||||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
@@ -1,16 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from random import random
|
from random import random
|
||||||
from nonebot.adapters.onebot.v11 import (
|
import json
|
||||||
Bot,
|
|
||||||
MessageEvent,
|
|
||||||
PrivateMessageEvent,
|
|
||||||
GroupMessageEvent,
|
|
||||||
NoticeEvent,
|
|
||||||
PokeNotifyEvent,
|
|
||||||
GroupRecallNoticeEvent,
|
|
||||||
FriendRecallNoticeEvent,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..memory_system.memory import hippocampus
|
from ..memory_system.memory import hippocampus
|
||||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
from ..moods.moods import MoodManager # 导入情绪管理器
|
||||||
@@ -18,9 +9,7 @@ from .config import global_config
|
|||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
from .message_cq import (
|
|
||||||
MessageRecvCQ,
|
|
||||||
)
|
|
||||||
from .chat_stream import chat_manager
|
from .chat_stream import chat_manager
|
||||||
|
|
||||||
from .message_sender import message_manager # 导入新的消息管理器
|
from .message_sender import message_manager # 导入新的消息管理器
|
||||||
@@ -30,7 +19,7 @@ from .utils import is_mentioned_bot_in_message
|
|||||||
from .utils_image import image_path_to_base64
|
from .utils_image import image_path_to_base64
|
||||||
from .utils_user import get_user_nickname, get_user_cardname
|
from .utils_user import get_user_nickname, get_user_cardname
|
||||||
from ..willing.willing_manager import willing_manager # 导入意愿管理器
|
from ..willing.willing_manager import willing_manager # 导入意愿管理器
|
||||||
from .message_base import UserInfo, GroupInfo, Seg
|
from ..message import UserInfo, GroupInfo, Seg
|
||||||
|
|
||||||
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
|
||||||
|
|
||||||
@@ -62,7 +51,7 @@ class ChatBot:
|
|||||||
if not self._started:
|
if not self._started:
|
||||||
self._started = True
|
self._started = True
|
||||||
|
|
||||||
async def message_process(self, message_cq: MessageRecvCQ) -> None:
|
async def message_process(self, message_data: str) -> None:
|
||||||
"""处理转化后的统一格式消息
|
"""处理转化后的统一格式消息
|
||||||
1. 过滤消息
|
1. 过滤消息
|
||||||
2. 记忆激活
|
2. 记忆激活
|
||||||
@@ -71,12 +60,11 @@ class ChatBot:
|
|||||||
5. 更新关系
|
5. 更新关系
|
||||||
6. 更新情绪
|
6. 更新情绪
|
||||||
"""
|
"""
|
||||||
await message_cq.initialize()
|
# message_json = json.loads(message_data)
|
||||||
message_json = message_cq.to_dict()
|
|
||||||
# 哦我嘞个json
|
# 哦我嘞个json
|
||||||
|
|
||||||
# 进入maimbot
|
# 进入maimbot
|
||||||
message = MessageRecv(message_json)
|
message = MessageRecv(message_data)
|
||||||
groupinfo = message.message_info.group_info
|
groupinfo = message.message_info.group_info
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
messageinfo = message.message_info
|
messageinfo = message.message_info
|
||||||
@@ -146,7 +134,7 @@ class ChatBot:
|
|||||||
|
|
||||||
response = None
|
response = None
|
||||||
# 开始组织语言
|
# 开始组织语言
|
||||||
if random() < reply_probability:
|
if random() < reply_probability + 100:
|
||||||
bot_user_info = UserInfo(
|
bot_user_info = UserInfo(
|
||||||
user_id=global_config.BOT_QQ,
|
user_id=global_config.BOT_QQ,
|
||||||
user_nickname=global_config.BOT_NICKNAME,
|
user_nickname=global_config.BOT_NICKNAME,
|
||||||
@@ -278,235 +266,6 @@ class ChatBot:
|
|||||||
# chat_stream=chat
|
# chat_stream=chat
|
||||||
# )
|
# )
|
||||||
|
|
||||||
async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None:
|
|
||||||
"""处理收到的通知"""
|
|
||||||
if isinstance(event, PokeNotifyEvent):
|
|
||||||
# 戳一戳 通知
|
|
||||||
# 不处理其他人的戳戳
|
|
||||||
if not event.is_tome():
|
|
||||||
return
|
|
||||||
|
|
||||||
# 用户屏蔽,不区分私聊/群聊
|
|
||||||
if event.user_id in global_config.ban_user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 白名单模式
|
|
||||||
if event.group_id:
|
|
||||||
if event.group_id not in global_config.talk_allowed_groups:
|
|
||||||
return
|
|
||||||
|
|
||||||
raw_message = f"[戳了戳]{global_config.BOT_NICKNAME}" # 默认类型
|
|
||||||
if info := event.model_extra["raw_info"]:
|
|
||||||
poke_type = info[2].get("txt", "戳了戳") # 戳戳类型,例如“拍一拍”、“揉一揉”、“捏一捏”
|
|
||||||
custom_poke_message = info[4].get("txt", "") # 自定义戳戳消息,若不存在会为空字符串
|
|
||||||
raw_message = f"[{poke_type}]{global_config.BOT_NICKNAME}{custom_poke_message}"
|
|
||||||
|
|
||||||
raw_message += "(这是一个类似摸摸头的友善行为,而不是恶意行为,请不要作出攻击发言)"
|
|
||||||
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=event.user_id,
|
|
||||||
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
|
|
||||||
user_cardname=None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.group_id:
|
|
||||||
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
|
|
||||||
else:
|
|
||||||
group_info = None
|
|
||||||
|
|
||||||
message_cq = MessageRecvCQ(
|
|
||||||
message_id=0,
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=str(raw_message),
|
|
||||||
group_info=group_info,
|
|
||||||
reply_message=None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.message_process(message_cq)
|
|
||||||
|
|
||||||
elif isinstance(event, GroupRecallNoticeEvent) or isinstance(event, FriendRecallNoticeEvent):
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=event.user_id,
|
|
||||||
user_nickname=get_user_nickname(event.user_id) or None,
|
|
||||||
user_cardname=get_user_cardname(event.user_id) or None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(event, GroupRecallNoticeEvent):
|
|
||||||
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
|
|
||||||
else:
|
|
||||||
group_info = None
|
|
||||||
|
|
||||||
chat = await chat_manager.get_or_create_stream(
|
|
||||||
platform=user_info.platform, user_info=user_info, group_info=group_info
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.storage.store_recalled_message(event.message_id, time.time(), chat)
|
|
||||||
|
|
||||||
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
|
|
||||||
"""处理收到的消息"""
|
|
||||||
|
|
||||||
self.bot = bot # 更新 bot 实例
|
|
||||||
|
|
||||||
# 用户屏蔽,不区分私聊/群聊
|
|
||||||
if event.user_id in global_config.ban_user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
event.reply
|
|
||||||
and hasattr(event.reply, "sender")
|
|
||||||
and hasattr(event.reply.sender, "user_id")
|
|
||||||
and event.reply.sender.user_id in global_config.ban_user_id
|
|
||||||
):
|
|
||||||
logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息")
|
|
||||||
return
|
|
||||||
# 处理私聊消息
|
|
||||||
if isinstance(event, PrivateMessageEvent):
|
|
||||||
if not global_config.enable_friend_chat: # 私聊过滤
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=event.user_id,
|
|
||||||
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
|
|
||||||
user_cardname=None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取陌生人信息失败: {e}")
|
|
||||||
return
|
|
||||||
logger.debug(user_info)
|
|
||||||
|
|
||||||
# group_info = GroupInfo(group_id=0, group_name="私聊", platform="qq")
|
|
||||||
group_info = None
|
|
||||||
|
|
||||||
# 处理群聊消息
|
|
||||||
else:
|
|
||||||
# 白名单设定由nontbot侧完成
|
|
||||||
if event.group_id:
|
|
||||||
if event.group_id not in global_config.talk_allowed_groups:
|
|
||||||
return
|
|
||||||
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=event.user_id,
|
|
||||||
user_nickname=event.sender.nickname,
|
|
||||||
user_cardname=event.sender.card or None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
|
|
||||||
|
|
||||||
# group_info = await bot.get_group_info(group_id=event.group_id)
|
|
||||||
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
|
||||||
|
|
||||||
message_cq = MessageRecvCQ(
|
|
||||||
message_id=event.message_id,
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=str(event.original_message),
|
|
||||||
group_info=group_info,
|
|
||||||
reply_message=event.reply,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.message_process(message_cq)
|
|
||||||
|
|
||||||
async def handle_forward_message(self, event: MessageEvent, bot: Bot) -> None:
|
|
||||||
"""专用于处理合并转发的消息处理器"""
|
|
||||||
|
|
||||||
# 用户屏蔽,不区分私聊/群聊
|
|
||||||
if event.user_id in global_config.ban_user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(event, GroupMessageEvent):
|
|
||||||
if event.group_id:
|
|
||||||
if event.group_id not in global_config.talk_allowed_groups:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取合并转发消息的详细信息
|
|
||||||
forward_info = await bot.get_forward_msg(message_id=event.message_id)
|
|
||||||
messages = forward_info["messages"]
|
|
||||||
|
|
||||||
# 构建合并转发消息的文本表示
|
|
||||||
processed_messages = []
|
|
||||||
for node in messages:
|
|
||||||
# 提取发送者昵称
|
|
||||||
nickname = node["sender"].get("nickname", "未知用户")
|
|
||||||
|
|
||||||
# 递归处理消息内容
|
|
||||||
message_content = await self.process_message_segments(node["message"], layer=0)
|
|
||||||
|
|
||||||
# 拼接为【昵称】+ 内容
|
|
||||||
processed_messages.append(f"【{nickname}】{message_content}")
|
|
||||||
|
|
||||||
# 组合所有消息
|
|
||||||
combined_message = "\n".join(processed_messages)
|
|
||||||
combined_message = f"合并转发消息内容:\n{combined_message}"
|
|
||||||
|
|
||||||
# 构建用户信息(使用转发消息的发送者)
|
|
||||||
user_info = UserInfo(
|
|
||||||
user_id=event.user_id,
|
|
||||||
user_nickname=event.sender.nickname,
|
|
||||||
user_cardname=event.sender.card if hasattr(event.sender, "card") else None,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建群聊信息(如果是群聊)
|
|
||||||
group_info = None
|
|
||||||
if isinstance(event, GroupMessageEvent):
|
|
||||||
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
|
|
||||||
|
|
||||||
# 创建消息对象
|
|
||||||
message_cq = MessageRecvCQ(
|
|
||||||
message_id=event.message_id,
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=combined_message,
|
|
||||||
group_info=group_info,
|
|
||||||
reply_message=event.reply,
|
|
||||||
platform="qq",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 进入标准消息处理流程
|
|
||||||
await self.message_process(message_cq)
|
|
||||||
|
|
||||||
async def process_message_segments(self, segments: list, layer: int) -> str:
|
|
||||||
"""递归处理消息段"""
|
|
||||||
parts = []
|
|
||||||
for seg in segments:
|
|
||||||
part = await self.process_segment(seg, layer + 1)
|
|
||||||
parts.append(part)
|
|
||||||
return "".join(parts)
|
|
||||||
|
|
||||||
async def process_segment(self, seg: dict, layer: int) -> str:
|
|
||||||
"""处理单个消息段"""
|
|
||||||
seg_type = seg["type"]
|
|
||||||
if layer > 3:
|
|
||||||
# 防止有那种100层转发消息炸飞麦麦
|
|
||||||
return "【转发消息】"
|
|
||||||
if seg_type == "text":
|
|
||||||
return seg["data"]["text"]
|
|
||||||
elif seg_type == "image":
|
|
||||||
return "[图片]"
|
|
||||||
elif seg_type == "face":
|
|
||||||
return "[表情]"
|
|
||||||
elif seg_type == "at":
|
|
||||||
return f"@{seg['data'].get('qq', '未知用户')}"
|
|
||||||
elif seg_type == "forward":
|
|
||||||
# 递归处理嵌套的合并转发消息
|
|
||||||
nested_nodes = seg["data"].get("content", [])
|
|
||||||
nested_messages = []
|
|
||||||
nested_messages.append("合并转发消息内容:")
|
|
||||||
for node in nested_nodes:
|
|
||||||
nickname = node["sender"].get("nickname", "未知用户")
|
|
||||||
content = await self.process_message_segments(node["message"], layer=layer)
|
|
||||||
# nested_messages.append('-' * layer)
|
|
||||||
nested_messages.append(f"{'--' * layer}【{nickname}】{content}")
|
|
||||||
# nested_messages.append(f"{'--' * layer}合并转发第【{layer}】层结束")
|
|
||||||
return "\n".join(nested_messages)
|
|
||||||
else:
|
|
||||||
return f"[{seg_type}]"
|
|
||||||
|
|
||||||
|
|
||||||
# 创建全局ChatBot实例
|
# 创建全局ChatBot实例
|
||||||
chat_bot = ChatBot()
|
chat_bot = ChatBot()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from .message_base import GroupInfo, UserInfo
|
from ..message.message_base import GroupInfo, UserInfo
|
||||||
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
|
|||||||
@@ -1,385 +0,0 @@
|
|||||||
import base64
|
|
||||||
import html
|
|
||||||
import asyncio
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Union
|
|
||||||
import ssl
|
|
||||||
import os
|
|
||||||
import aiohttp
|
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
|
||||||
from .config import global_config
|
|
||||||
from .mapper import emojimapper
|
|
||||||
from .message_base import Seg
|
|
||||||
from .utils_user import get_user_nickname, get_groupname
|
|
||||||
from .message_base import GroupInfo, UserInfo
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
# 创建SSL上下文
|
|
||||||
ssl_context = ssl.create_default_context()
|
|
||||||
ssl_context.set_ciphers("AES128-GCM-SHA256")
|
|
||||||
|
|
||||||
logger = get_module_logger("cq_code")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CQCode:
|
|
||||||
"""
|
|
||||||
CQ码数据类,用于存储和处理CQ码
|
|
||||||
|
|
||||||
属性:
|
|
||||||
type: CQ码类型(如'image', 'at', 'face'等)
|
|
||||||
params: CQ码的参数字典
|
|
||||||
raw_code: 原始CQ码字符串
|
|
||||||
translated_segments: 经过处理后的Seg对象列表
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: str
|
|
||||||
params: Dict[str, str]
|
|
||||||
group_info: Optional[GroupInfo] = None
|
|
||||||
user_info: Optional[UserInfo] = None
|
|
||||||
translated_segments: Optional[Union[Seg, List[Seg]]] = None
|
|
||||||
reply_message: Dict = None # 存储回复消息
|
|
||||||
image_base64: Optional[str] = None
|
|
||||||
_llm: Optional[LLM_request] = None
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""初始化LLM实例"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def translate(self):
|
|
||||||
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
|
||||||
if self.type == "text":
|
|
||||||
self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
|
|
||||||
elif self.type == "image":
|
|
||||||
base64_data = await self.translate_image()
|
|
||||||
if base64_data:
|
|
||||||
if self.params.get("sub_type") == "0":
|
|
||||||
self.translated_segments = Seg(type="image", data=base64_data)
|
|
||||||
else:
|
|
||||||
self.translated_segments = Seg(type="emoji", data=base64_data)
|
|
||||||
else:
|
|
||||||
self.translated_segments = Seg(type="text", data="[图片]")
|
|
||||||
elif self.type == "at":
|
|
||||||
if self.params.get("qq") == "all":
|
|
||||||
self.translated_segments = Seg(type="text", data="@[全体成员]")
|
|
||||||
else:
|
|
||||||
user_nickname = get_user_nickname(self.params.get("qq", ""))
|
|
||||||
self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
|
|
||||||
elif self.type == "reply":
|
|
||||||
reply_segments = await self.translate_reply()
|
|
||||||
if reply_segments:
|
|
||||||
self.translated_segments = Seg(type="seglist", data=reply_segments)
|
|
||||||
else:
|
|
||||||
self.translated_segments = Seg(type="text", data="[回复某人消息]")
|
|
||||||
elif self.type == "face":
|
|
||||||
face_id = self.params.get("id", "")
|
|
||||||
self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
|
|
||||||
elif self.type == "forward":
|
|
||||||
forward_segments = await self.translate_forward()
|
|
||||||
if forward_segments:
|
|
||||||
self.translated_segments = Seg(type="seglist", data=forward_segments)
|
|
||||||
else:
|
|
||||||
self.translated_segments = Seg(type="text", data="[转发消息]")
|
|
||||||
else:
|
|
||||||
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
|
|
||||||
|
|
||||||
async def get_img(self) -> Optional[str]:
|
|
||||||
"""异步获取图片并转换为base64"""
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
|
|
||||||
"Chrome/50.0.2661.87 Safari/537.36",
|
|
||||||
"Accept": "text/html, application/xhtml xml, */*",
|
|
||||||
"Accept-Encoding": "gbk, GB2312",
|
|
||||||
"Accept-Language": "zh-cn",
|
|
||||||
"Content-Type": "application/x-www-form-urlencoded",
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
}
|
|
||||||
|
|
||||||
url = html.unescape(self.params["url"])
|
|
||||||
if not url.startswith(("http://", "https://")):
|
|
||||||
return None
|
|
||||||
|
|
||||||
max_retries = 3
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
logger.debug(f"获取图片中: {url}")
|
|
||||||
# 设置SSL上下文和创建连接器
|
|
||||||
conn = aiohttp.TCPConnector(ssl=ssl_context)
|
|
||||||
async with aiohttp.ClientSession(connector=conn) as session:
|
|
||||||
async with session.get(
|
|
||||||
url,
|
|
||||||
headers=headers,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=15),
|
|
||||||
allow_redirects=True,
|
|
||||||
) as response:
|
|
||||||
# 腾讯服务器特殊状态码处理
|
|
||||||
if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if response.status != 200:
|
|
||||||
raise aiohttp.ClientError(f"HTTP {response.status}")
|
|
||||||
|
|
||||||
# 验证内容类型
|
|
||||||
content_type = response.headers.get("Content-Type", "")
|
|
||||||
if not content_type.startswith("image/"):
|
|
||||||
raise ValueError(f"非图片内容类型: {content_type}")
|
|
||||||
|
|
||||||
# 读取响应内容
|
|
||||||
content = await response.read()
|
|
||||||
logger.debug(f"获取图片成功: {url}")
|
|
||||||
|
|
||||||
# 转换为Base64
|
|
||||||
image_base64 = base64.b64encode(content).decode("utf-8")
|
|
||||||
self.image_base64 = image_base64
|
|
||||||
return image_base64
|
|
||||||
|
|
||||||
except (aiohttp.ClientError, ValueError) as e:
|
|
||||||
if retry == max_retries - 1:
|
|
||||||
logger.error(f"最终请求失败: {str(e)}")
|
|
||||||
await asyncio.sleep(1.5**retry) # 指数退避
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"获取图片时发生未知错误: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def translate_image(self) -> Optional[str]:
|
|
||||||
"""处理图片类型的CQ码,返回base64字符串"""
|
|
||||||
if "url" not in self.params:
|
|
||||||
return None
|
|
||||||
return await self.get_img()
|
|
||||||
|
|
||||||
async def translate_forward(self) -> Optional[List[Seg]]:
|
|
||||||
"""处理转发消息,返回Seg列表"""
|
|
||||||
try:
|
|
||||||
if "content" not in self.params:
|
|
||||||
return None
|
|
||||||
|
|
||||||
content = self.unescape(self.params["content"])
|
|
||||||
import ast
|
|
||||||
|
|
||||||
try:
|
|
||||||
messages = ast.literal_eval(content)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"解析转发消息内容失败: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
formatted_segments = []
|
|
||||||
for msg in messages:
|
|
||||||
sender = msg.get("sender", {})
|
|
||||||
nickname = sender.get("card") or sender.get("nickname", "未知用户")
|
|
||||||
raw_message = msg.get("raw_message", "")
|
|
||||||
message_array = msg.get("message", [])
|
|
||||||
|
|
||||||
if message_array and isinstance(message_array, list):
|
|
||||||
for message_part in message_array:
|
|
||||||
if message_part.get("type") == "forward":
|
|
||||||
content_seg = Seg(type="text", data="[转发消息]")
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if raw_message:
|
|
||||||
from .message_cq import MessageRecvCQ
|
|
||||||
|
|
||||||
user_info = UserInfo(
|
|
||||||
platform="qq",
|
|
||||||
user_id=msg.get("user_id", 0),
|
|
||||||
user_nickname=nickname,
|
|
||||||
)
|
|
||||||
group_info = GroupInfo(
|
|
||||||
platform="qq",
|
|
||||||
group_id=msg.get("group_id", 0),
|
|
||||||
group_name=get_groupname(msg.get("group_id", 0)),
|
|
||||||
)
|
|
||||||
|
|
||||||
message_obj = MessageRecvCQ(
|
|
||||||
message_id=msg.get("message_id", 0),
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=raw_message,
|
|
||||||
plain_text=raw_message,
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
await message_obj.initialize()
|
|
||||||
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
|
|
||||||
else:
|
|
||||||
content_seg = Seg(type="text", data="[空消息]")
|
|
||||||
else:
|
|
||||||
if raw_message:
|
|
||||||
from .message_cq import MessageRecvCQ
|
|
||||||
|
|
||||||
user_info = UserInfo(
|
|
||||||
platform="qq",
|
|
||||||
user_id=msg.get("user_id", 0),
|
|
||||||
user_nickname=nickname,
|
|
||||||
)
|
|
||||||
group_info = GroupInfo(
|
|
||||||
platform="qq",
|
|
||||||
group_id=msg.get("group_id", 0),
|
|
||||||
group_name=get_groupname(msg.get("group_id", 0)),
|
|
||||||
)
|
|
||||||
message_obj = MessageRecvCQ(
|
|
||||||
message_id=msg.get("message_id", 0),
|
|
||||||
user_info=user_info,
|
|
||||||
raw_message=raw_message,
|
|
||||||
plain_text=raw_message,
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
await message_obj.initialize()
|
|
||||||
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
|
|
||||||
else:
|
|
||||||
content_seg = Seg(type="text", data="[空消息]")
|
|
||||||
|
|
||||||
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
|
|
||||||
formatted_segments.append(content_seg)
|
|
||||||
formatted_segments.append(Seg(type="text", data="\n"))
|
|
||||||
|
|
||||||
return formatted_segments
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理转发消息失败: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def translate_reply(self) -> Optional[List[Seg]]:
|
|
||||||
"""处理回复类型的CQ码,返回Seg列表"""
|
|
||||||
from .message_cq import MessageRecvCQ
|
|
||||||
|
|
||||||
if self.reply_message is None:
|
|
||||||
return None
|
|
||||||
if hasattr(self.reply_message, "group_id"):
|
|
||||||
group_info = GroupInfo(platform="qq", group_id=self.reply_message.group_id, group_name="")
|
|
||||||
else:
|
|
||||||
group_info = None
|
|
||||||
|
|
||||||
if self.reply_message.sender.user_id:
|
|
||||||
message_obj = MessageRecvCQ(
|
|
||||||
user_info=UserInfo(
|
|
||||||
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
|
|
||||||
),
|
|
||||||
message_id=self.reply_message.message_id,
|
|
||||||
raw_message=str(self.reply_message.message),
|
|
||||||
group_info=group_info,
|
|
||||||
)
|
|
||||||
await message_obj.initialize()
|
|
||||||
|
|
||||||
segments = []
|
|
||||||
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
|
|
||||||
segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
|
|
||||||
else:
|
|
||||||
segments.append(
|
|
||||||
Seg(
|
|
||||||
type="text",
|
|
||||||
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
segments.append(Seg(type="seglist", data=[message_obj.message_segment]))
|
|
||||||
segments.append(Seg(type="text", data="]"))
|
|
||||||
return segments
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def unescape(text: str) -> str:
|
|
||||||
"""反转义CQ码中的特殊字符"""
|
|
||||||
return text.replace(",", ",").replace("[", "[").replace("]", "]").replace("&", "&")
|
|
||||||
|
|
||||||
|
|
||||||
class CQCode_tool:
|
|
||||||
@staticmethod
|
|
||||||
def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
|
|
||||||
"""
|
|
||||||
将CQ码字典转换为CQCode对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cq_code: CQ码字典
|
|
||||||
msg: MessageCQ对象
|
|
||||||
reply: 回复消息的字典(可选)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CQCode对象
|
|
||||||
"""
|
|
||||||
# 处理字典形式的CQ码
|
|
||||||
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
|
|
||||||
cq_type = cq_code.get("type", "text")
|
|
||||||
params = {}
|
|
||||||
if cq_type == "text":
|
|
||||||
params["text"] = cq_code.get("data", {}).get("text", "")
|
|
||||||
else:
|
|
||||||
params = cq_code.get("data", {})
|
|
||||||
|
|
||||||
instance = CQCode(
|
|
||||||
type=cq_type,
|
|
||||||
params=params,
|
|
||||||
group_info=msg.message_info.group_info,
|
|
||||||
user_info=msg.message_info.user_info,
|
|
||||||
reply_message=reply,
|
|
||||||
)
|
|
||||||
|
|
||||||
return instance
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_reply_cq(message_id: int) -> str:
|
|
||||||
"""
|
|
||||||
创建回复CQ码
|
|
||||||
Args:
|
|
||||||
message_id: 回复的消息ID
|
|
||||||
Returns:
|
|
||||||
回复CQ码字符串
|
|
||||||
"""
|
|
||||||
return f"[CQ:reply,id={message_id}]"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_emoji_cq(file_path: str) -> str:
|
|
||||||
"""
|
|
||||||
创建表情包CQ码
|
|
||||||
Args:
|
|
||||||
file_path: 本地表情包文件路径
|
|
||||||
Returns:
|
|
||||||
表情包CQ码字符串
|
|
||||||
"""
|
|
||||||
# 确保使用绝对路径
|
|
||||||
abs_path = os.path.abspath(file_path)
|
|
||||||
# 转义特殊字符
|
|
||||||
escaped_path = abs_path.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
|
||||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_emoji_cq_base64(base64_data: str) -> str:
|
|
||||||
"""
|
|
||||||
创建表情包CQ码
|
|
||||||
Args:
|
|
||||||
base64_data: base64编码的表情包数据
|
|
||||||
Returns:
|
|
||||||
表情包CQ码字符串
|
|
||||||
"""
|
|
||||||
# 转义base64数据
|
|
||||||
escaped_base64 = (
|
|
||||||
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
|
||||||
)
|
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
|
||||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_image_cq_base64(base64_data: str) -> str:
|
|
||||||
"""
|
|
||||||
创建表情包CQ码
|
|
||||||
Args:
|
|
||||||
base64_data: base64编码的表情包数据
|
|
||||||
Returns:
|
|
||||||
表情包CQ码字符串
|
|
||||||
"""
|
|
||||||
# 转义base64数据
|
|
||||||
escaped_base64 = (
|
|
||||||
base64_data.replace("&", "&").replace("[", "[").replace("]", "]").replace(",", ",")
|
|
||||||
)
|
|
||||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
|
||||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
|
|
||||||
|
|
||||||
|
|
||||||
cq_code_tool = CQCode_tool()
|
|
||||||
@@ -9,8 +9,6 @@ from typing import Optional, Tuple
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ..chat.utils import get_embedding
|
from ..chat.utils import get_embedding
|
||||||
@@ -21,8 +19,6 @@ from src.common.logger import get_module_logger
|
|||||||
logger = get_module_logger("emoji")
|
logger = get_module_logger("emoji")
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
image_manager = ImageManager()
|
image_manager = ImageManager()
|
||||||
|
|
||||||
|
|
||||||
@@ -118,9 +114,11 @@ class EmojiManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取所有表情包
|
# 获取所有表情包
|
||||||
all_emojis = [e for e in
|
all_emojis = [
|
||||||
db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1})
|
e
|
||||||
if 'blacklist' not in e]
|
for e in db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1})
|
||||||
|
if "blacklist" not in e
|
||||||
|
]
|
||||||
|
|
||||||
if not all_emojis:
|
if not all_emojis:
|
||||||
logger.warning("数据库中没有任何表情包")
|
logger.warning("数据库中没有任何表情包")
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import random
|
|||||||
import time
|
import time
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
@@ -21,9 +20,6 @@ llm_config = LogConfig(
|
|||||||
|
|
||||||
logger = get_module_logger("llm_generator", config=llm_config)
|
logger = get_module_logger("llm_generator", config=llm_config)
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseGenerator:
|
class ResponseGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import urllib3
|
|||||||
|
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
|
|
||||||
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
@@ -75,19 +75,6 @@ class MessageRecv(Message):
|
|||||||
"""
|
"""
|
||||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||||
|
|
||||||
message_segment = message_dict.get("message_segment", {})
|
|
||||||
|
|
||||||
if message_segment.get("data", "") == "[json]":
|
|
||||||
# 提取json消息中的展示信息
|
|
||||||
pattern = r"\[CQ:json,data=(?P<json_data>.+?)\]"
|
|
||||||
match = re.search(pattern, message_dict.get("raw_message", ""))
|
|
||||||
raw_json = html.unescape(match.group("json_data"))
|
|
||||||
try:
|
|
||||||
json_message = json.loads(raw_json)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
json_message = {}
|
|
||||||
message_segment["data"] = json_message.get("prompt", "")
|
|
||||||
|
|
||||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||||
self.raw_message = message_dict.get("raw_message")
|
self.raw_message = message_dict.get("raw_message")
|
||||||
|
|
||||||
|
|||||||
@@ -1,170 +0,0 @@
|
|||||||
import time
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import urllib3
|
|
||||||
|
|
||||||
from .cq_code import cq_code_tool
|
|
||||||
from .utils_cq import parse_cq_code
|
|
||||||
from .utils_user import get_groupname
|
|
||||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
|
||||||
|
|
||||||
# 禁用SSL警告
|
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|
||||||
|
|
||||||
# 这个类是消息数据类,用于存储和管理消息数据。
|
|
||||||
# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
|
||||||
# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageCQ(MessageBase):
|
|
||||||
"""QQ消息基类,继承自MessageBase
|
|
||||||
|
|
||||||
最小必要参数:
|
|
||||||
- message_id: 消息ID
|
|
||||||
- user_id: 发送者/接收者ID
|
|
||||||
- platform: 平台标识(默认为"qq")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq"
|
|
||||||
):
|
|
||||||
# 构造基础消息信息
|
|
||||||
message_info = BaseMessageInfo(
|
|
||||||
platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info
|
|
||||||
)
|
|
||||||
# 调用父类初始化,message_segment 由子类设置
|
|
||||||
super().__init__(message_info=message_info, message_segment=None, raw_message=None)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageRecvCQ(MessageCQ):
|
|
||||||
"""QQ接收消息类,用于解析raw_message到Seg对象"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_id: int,
|
|
||||||
user_info: UserInfo,
|
|
||||||
raw_message: str,
|
|
||||||
group_info: Optional[GroupInfo] = None,
|
|
||||||
platform: str = "qq",
|
|
||||||
reply_message: Optional[Dict] = None,
|
|
||||||
):
|
|
||||||
# 调用父类初始化
|
|
||||||
super().__init__(message_id, user_info, group_info, platform)
|
|
||||||
|
|
||||||
# 私聊消息不携带group_info
|
|
||||||
if group_info is None:
|
|
||||||
pass
|
|
||||||
elif group_info.group_name is None:
|
|
||||||
group_info.group_name = get_groupname(group_info.group_id)
|
|
||||||
|
|
||||||
# 解析消息段
|
|
||||||
self.message_segment = None # 初始化为None
|
|
||||||
self.raw_message = raw_message
|
|
||||||
# 异步初始化在外部完成
|
|
||||||
|
|
||||||
# 添加对reply的解析
|
|
||||||
self.reply_message = reply_message
|
|
||||||
|
|
||||||
async def initialize(self):
|
|
||||||
"""异步初始化方法"""
|
|
||||||
self.message_segment = await self._parse_message(self.raw_message, self.reply_message)
|
|
||||||
|
|
||||||
async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
|
|
||||||
"""异步解析消息内容为Seg对象"""
|
|
||||||
cq_code_dict_list = []
|
|
||||||
segments = []
|
|
||||||
|
|
||||||
start = 0
|
|
||||||
while True:
|
|
||||||
cq_start = message.find("[CQ:", start)
|
|
||||||
if cq_start == -1:
|
|
||||||
if start < len(message):
|
|
||||||
text = message[start:].strip()
|
|
||||||
if text:
|
|
||||||
cq_code_dict_list.append(parse_cq_code(text))
|
|
||||||
break
|
|
||||||
|
|
||||||
if cq_start > start:
|
|
||||||
text = message[start:cq_start].strip()
|
|
||||||
if text:
|
|
||||||
cq_code_dict_list.append(parse_cq_code(text))
|
|
||||||
|
|
||||||
cq_end = message.find("]", cq_start)
|
|
||||||
if cq_end == -1:
|
|
||||||
text = message[cq_start:].strip()
|
|
||||||
if text:
|
|
||||||
cq_code_dict_list.append(parse_cq_code(text))
|
|
||||||
break
|
|
||||||
|
|
||||||
cq_code = message[cq_start : cq_end + 1]
|
|
||||||
cq_code_dict_list.append(parse_cq_code(cq_code))
|
|
||||||
start = cq_end + 1
|
|
||||||
|
|
||||||
# 转换CQ码为Seg对象
|
|
||||||
for code_item in cq_code_dict_list:
|
|
||||||
cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
|
|
||||||
await cq_code_obj.translate() # 异步调用translate
|
|
||||||
if cq_code_obj.translated_segments:
|
|
||||||
segments.append(cq_code_obj.translated_segments)
|
|
||||||
|
|
||||||
# 如果只有一个segment,直接返回
|
|
||||||
if len(segments) == 1:
|
|
||||||
return segments[0]
|
|
||||||
|
|
||||||
# 否则返回seglist类型的Seg
|
|
||||||
return Seg(type="seglist", data=segments)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
|
||||||
"""转换为字典格式,包含所有必要信息"""
|
|
||||||
base_dict = super().to_dict()
|
|
||||||
return base_dict
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MessageSendCQ(MessageCQ):
|
|
||||||
"""QQ发送消息类,用于将Seg对象转换为raw_message"""
|
|
||||||
|
|
||||||
def __init__(self, data: Dict):
|
|
||||||
# 调用父类初始化
|
|
||||||
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
|
|
||||||
message_segment = Seg.from_dict(data.get("message_segment", {}))
|
|
||||||
super().__init__(
|
|
||||||
message_info.message_id,
|
|
||||||
message_info.user_info,
|
|
||||||
message_info.group_info if message_info.group_info else None,
|
|
||||||
message_info.platform,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.message_segment = message_segment
|
|
||||||
self.raw_message = self._generate_raw_message()
|
|
||||||
|
|
||||||
def _generate_raw_message(self) -> str:
|
|
||||||
"""将Seg对象转换为raw_message"""
|
|
||||||
segments = []
|
|
||||||
|
|
||||||
# 处理消息段
|
|
||||||
if self.message_segment.type == "seglist":
|
|
||||||
for seg in self.message_segment.data:
|
|
||||||
segments.append(self._seg_to_cq_code(seg))
|
|
||||||
else:
|
|
||||||
segments.append(self._seg_to_cq_code(self.message_segment))
|
|
||||||
|
|
||||||
return "".join(segments)
|
|
||||||
|
|
||||||
def _seg_to_cq_code(self, seg: Seg) -> str:
|
|
||||||
"""将单个Seg对象转换为CQ码字符串"""
|
|
||||||
if seg.type == "text":
|
|
||||||
return str(seg.data)
|
|
||||||
elif seg.type == "image":
|
|
||||||
return cq_code_tool.create_image_cq_base64(seg.data)
|
|
||||||
elif seg.type == "emoji":
|
|
||||||
return cq_code_tool.create_emoji_cq_base64(seg.data)
|
|
||||||
elif seg.type == "at":
|
|
||||||
return f"[CQ:at,qq={seg.data}]"
|
|
||||||
elif seg.type == "reply":
|
|
||||||
return cq_code_tool.create_reply_cq(int(seg.data))
|
|
||||||
else:
|
|
||||||
return f"[{seg.data}]"
|
|
||||||
@@ -3,9 +3,8 @@ import time
|
|||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from .message_cq import MessageSendCQ
|
from ..message.api import global_api
|
||||||
from .message import MessageSending, MessageThinking, MessageSet
|
from .message import MessageSending, MessageThinking, MessageSet
|
||||||
|
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
@@ -32,9 +31,9 @@ class Message_Sender:
|
|||||||
self.last_send_time = 0
|
self.last_send_time = 0
|
||||||
self._current_bot = None
|
self._current_bot = None
|
||||||
|
|
||||||
def set_bot(self, bot: Bot):
|
def set_bot(self, bot):
|
||||||
"""设置当前bot实例"""
|
"""设置当前bot实例"""
|
||||||
self._current_bot = bot
|
pass
|
||||||
|
|
||||||
def get_recalled_messages(self, stream_id: str) -> list:
|
def get_recalled_messages(self, stream_id: str) -> list:
|
||||||
"""获取所有撤回的消息"""
|
"""获取所有撤回的消息"""
|
||||||
@@ -60,31 +59,14 @@ class Message_Sender:
|
|||||||
break
|
break
|
||||||
if not is_recalled:
|
if not is_recalled:
|
||||||
message_json = message.to_dict()
|
message_json = message.to_dict()
|
||||||
message_send = MessageSendCQ(data=message_json)
|
|
||||||
message_preview = truncate_message(message.processed_plain_text)
|
message_preview = truncate_message(message.processed_plain_text)
|
||||||
if message_send.message_info.group_info and message_send.message_info.group_info.group_id:
|
|
||||||
try:
|
try:
|
||||||
await self._current_bot.send_group_msg(
|
result = await global_api.send_message("http://127.0.0.1:18002/api/message", message_json)
|
||||||
group_id=message.message_info.group_info.group_id,
|
if result["status"] == "success":
|
||||||
message=message_send.raw_message,
|
|
||||||
auto_escape=False,
|
|
||||||
)
|
|
||||||
logger.success(f"发送消息“{message_preview}”成功")
|
logger.success(f"发送消息“{message_preview}”成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[调试] 发生错误 {e}")
|
logger.error(f"发送消息“{message_preview}”失败: {str(e)}")
|
||||||
logger.error(f"[调试] 发送消息“{message_preview}”失败")
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
logger.debug(message.message_info.user_info)
|
|
||||||
await self._current_bot.send_private_msg(
|
|
||||||
user_id=message.sender_info.user_id,
|
|
||||||
message=message_send.raw_message,
|
|
||||||
auto_escape=False,
|
|
||||||
)
|
|
||||||
logger.success(f"发送消息“{message_preview}”成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[调试] 发生错误 {e}")
|
|
||||||
logger.error(f"[调试] 发送消息“{message_preview}”失败")
|
|
||||||
|
|
||||||
|
|
||||||
class MessageContainer:
|
class MessageContainer:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from .message_base import UserInfo
|
from ..message.message_base import UserInfo
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
import math
|
import math
|
||||||
from bson.decimal128 import Decimal128
|
from bson.decimal128 import Decimal128
|
||||||
@@ -122,11 +122,15 @@ class RelationshipManager:
|
|||||||
relationship.relationship_value = float(relationship.relationship_value.to_decimal())
|
relationship.relationship_value = float(relationship.relationship_value.to_decimal())
|
||||||
else:
|
else:
|
||||||
relationship.relationship_value = float(relationship.relationship_value)
|
relationship.relationship_value = float(relationship.relationship_value)
|
||||||
logger.info(f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}")
|
logger.info(
|
||||||
|
f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}"
|
||||||
|
)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# 如果不能解析/强转则将relationship.relationship_value设置为double类型的0
|
# 如果不能解析/强转则将relationship.relationship_value设置为double类型的0
|
||||||
relationship.relationship_value = 0.0
|
relationship.relationship_value = 0.0
|
||||||
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型,已设置为0")
|
logger.warning(
|
||||||
|
f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型,已设置为0"
|
||||||
|
)
|
||||||
relationship.relationship_value += value
|
relationship.relationship_value += value
|
||||||
await self.storage_relationship(relationship)
|
await self.storage_relationship(relationship)
|
||||||
relationship.saved = True
|
relationship.saved = True
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
@@ -15,9 +14,6 @@ topic_config = LogConfig(
|
|||||||
|
|
||||||
logger = get_module_logger("topic_identifier", config=topic_config)
|
logger = get_module_logger("topic_identifier", config=topic_config)
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
|
|
||||||
class TopicIdentifier:
|
class TopicIdentifier:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
@@ -7,20 +7,17 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import jieba
|
import jieba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from nonebot import get_driver
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from ..utils.typo_generator import ChineseTypoGenerator
|
from ..utils.typo_generator import ChineseTypoGenerator
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message import MessageRecv, Message
|
from .message import MessageRecv, Message
|
||||||
from .message_base import UserInfo
|
from ..message.message_base import UserInfo
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from ..moods.moods import MoodManager
|
from ..moods.moods import MoodManager
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
logger = get_module_logger("chat_utils")
|
logger = get_module_logger("chat_utils")
|
||||||
|
|
||||||
@@ -364,10 +361,10 @@ def random_remove_punctuation(text: str) -> str:
|
|||||||
def process_llm_response(text: str) -> List[str]:
|
def process_llm_response(text: str) -> List[str]:
|
||||||
# processed_response = process_text_with_typos(content)
|
# processed_response = process_text_with_typos(content)
|
||||||
# 对西文字符段落的回复长度设置为汉字字符的两倍
|
# 对西文字符段落的回复长度设置为汉字字符的两倍
|
||||||
if len(text) > 100 and not is_western_paragraph(text) :
|
if len(text) > 100 and not is_western_paragraph(text):
|
||||||
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
||||||
return ["懒得说"]
|
return ["懒得说"]
|
||||||
elif len(text) > 200 :
|
elif len(text) > 200:
|
||||||
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
|
||||||
return ["懒得说"]
|
return ["懒得说"]
|
||||||
# 处理长消息
|
# 处理长消息
|
||||||
@@ -533,9 +530,9 @@ def recover_kaomoji(sentences, placeholder_to_kaomoji):
|
|||||||
|
|
||||||
def is_western_char(char):
|
def is_western_char(char):
|
||||||
"""检测是否为西文字符"""
|
"""检测是否为西文字符"""
|
||||||
return len(char.encode('utf-8')) <= 2
|
return len(char.encode("utf-8")) <= 2
|
||||||
|
|
||||||
|
|
||||||
def is_western_paragraph(paragraph):
|
def is_western_paragraph(paragraph):
|
||||||
"""检测是否为西文字符段落"""
|
"""检测是否为西文字符段落"""
|
||||||
return all(is_western_char(char) for char in paragraph if char.isalnum())
|
return all(is_western_char(char) for char in paragraph if char.isalnum())
|
||||||
|
|
||||||
@@ -6,7 +6,6 @@ from typing import Optional
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
@@ -16,9 +15,6 @@ from src.common.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger("chat_image")
|
logger = get_module_logger("chat_image")
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
|
|
||||||
class ImageManager:
|
class ImageManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
|
|||||||
@@ -1,11 +1 @@
|
|||||||
from nonebot import get_app
|
|
||||||
from .api import router
|
|
||||||
from src.common.logger import get_module_logger
|
|
||||||
|
|
||||||
# 获取主应用实例并挂载路由
|
|
||||||
app = get_app()
|
|
||||||
app.include_router(router, prefix="/api")
|
|
||||||
|
|
||||||
# 打印日志,方便确认API已注册
|
|
||||||
logger = get_module_logger("cfg_reload")
|
|
||||||
logger.success("配置重载API已注册,可通过 /api/reload-config 访问")
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import re
|
|||||||
import jieba
|
import jieba
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ..chat.utils import (
|
from ..chat.utils import (
|
||||||
@@ -238,7 +237,7 @@ class Hippocampus:
|
|||||||
n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值(24小时前)
|
n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值(24小时前)
|
||||||
std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差
|
std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差
|
||||||
weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40%
|
weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40%
|
||||||
total_samples=global_config.build_memory_sample_num # 总共生成10个时间点
|
total_samples=global_config.build_memory_sample_num, # 总共生成10个时间点
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成时间戳数组
|
# 生成时间戳数组
|
||||||
@@ -250,9 +249,7 @@ class Hippocampus:
|
|||||||
chat_samples = []
|
chat_samples = []
|
||||||
for timestamp in timestamps:
|
for timestamp in timestamps:
|
||||||
messages = self.random_get_msg_snippet(
|
messages = self.random_get_msg_snippet(
|
||||||
timestamp,
|
timestamp, global_config.build_memory_sample_length, max_memorized_time_per_msg
|
||||||
global_config.build_memory_sample_length,
|
|
||||||
max_memorized_time_per_msg
|
|
||||||
)
|
)
|
||||||
if messages:
|
if messages:
|
||||||
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
|
||||||
@@ -297,16 +294,16 @@ class Hippocampus:
|
|||||||
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
|
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
|
||||||
|
|
||||||
# 使用正则表达式提取<>中的内容
|
# 使用正则表达式提取<>中的内容
|
||||||
topics = re.findall(r'<([^>]+)>', topics_response[0])
|
topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||||
|
|
||||||
# 如果没有找到<>包裹的内容,返回['none']
|
# 如果没有找到<>包裹的内容,返回['none']
|
||||||
if not topics:
|
if not topics:
|
||||||
topics = ['none']
|
topics = ["none"]
|
||||||
else:
|
else:
|
||||||
# 处理提取出的话题
|
# 处理提取出的话题
|
||||||
topics = [
|
topics = [
|
||||||
topic.strip()
|
topic.strip()
|
||||||
for topic in ','.join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
if topic.strip()
|
if topic.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -314,8 +311,7 @@ class Hippocampus:
|
|||||||
# any()检查topic中是否包含任何一个filter_keywords中的关键词
|
# any()检查topic中是否包含任何一个filter_keywords中的关键词
|
||||||
# 只保留不包含禁用关键词的topic
|
# 只保留不包含禁用关键词的topic
|
||||||
filtered_topics = [
|
filtered_topics = [
|
||||||
topic for topic in topics
|
topic for topic in topics if not any(keyword in topic for keyword in global_config.memory_ban_words)
|
||||||
if not any(keyword in topic for keyword in global_config.memory_ban_words)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.debug(f"过滤后话题: {filtered_topics}")
|
logger.debug(f"过滤后话题: {filtered_topics}")
|
||||||
@@ -445,8 +441,7 @@ class Hippocampus:
|
|||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.success(
|
logger.success(
|
||||||
f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
|
f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒--------------------------"
|
||||||
"秒--------------------------"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def sync_memory_to_db(self):
|
def sync_memory_to_db(self):
|
||||||
@@ -800,16 +795,16 @@ class Hippocampus:
|
|||||||
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 4))
|
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 4))
|
||||||
# 使用正则表达式提取<>中的内容
|
# 使用正则表达式提取<>中的内容
|
||||||
print(f"话题: {topics_response[0]}")
|
print(f"话题: {topics_response[0]}")
|
||||||
topics = re.findall(r'<([^>]+)>', topics_response[0])
|
topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||||
|
|
||||||
# 如果没有找到<>包裹的内容,返回['none']
|
# 如果没有找到<>包裹的内容,返回['none']
|
||||||
if not topics:
|
if not topics:
|
||||||
topics = ['none']
|
topics = ["none"]
|
||||||
else:
|
else:
|
||||||
# 处理提取出的话题
|
# 处理提取出的话题
|
||||||
topics = [
|
topics = [
|
||||||
topic.strip()
|
topic.strip()
|
||||||
for topic in ','.join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
if topic.strip()
|
if topic.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -994,9 +989,6 @@ def segment_text(text):
|
|||||||
return seg_text
|
return seg_text
|
||||||
|
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 创建记忆图
|
# 创建记忆图
|
||||||
|
|||||||
26
src/plugins/message/__init__.py
Normal file
26
src/plugins/message/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Maim Message - A message handling library"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
from .api import BaseMessageAPI, global_api
|
||||||
|
from .message_base import (
|
||||||
|
Seg,
|
||||||
|
GroupInfo,
|
||||||
|
UserInfo,
|
||||||
|
FormatInfo,
|
||||||
|
TemplateInfo,
|
||||||
|
BaseMessageInfo,
|
||||||
|
MessageBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseMessageAPI",
|
||||||
|
"Seg",
|
||||||
|
"global_api",
|
||||||
|
"GroupInfo",
|
||||||
|
"UserInfo",
|
||||||
|
"FormatInfo",
|
||||||
|
"TemplateInfo",
|
||||||
|
"BaseMessageInfo",
|
||||||
|
"MessageBase",
|
||||||
|
]
|
||||||
86
src/plugins/message/api.py
Normal file
86
src/plugins/message/api.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from typing import Optional, Dict, Any, Callable, List
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import uvicorn
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMessageAPI:
|
||||||
|
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
|
||||||
|
self.app = FastAPI()
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.message_handlers: List[Callable] = []
|
||||||
|
self._setup_routes()
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def _setup_routes(self):
|
||||||
|
"""设置基础路由"""
|
||||||
|
|
||||||
|
@self.app.post("/api/message")
|
||||||
|
async def handle_message(message: Dict[str, Any]):
|
||||||
|
# try:
|
||||||
|
for handler in self.message_handlers:
|
||||||
|
await handler(message)
|
||||||
|
return {"status": "success"}
|
||||||
|
# except Exception as e:
|
||||||
|
# raise HTTPException(status_code=500, detail=str(e)) from e
|
||||||
|
|
||||||
|
def register_message_handler(self, handler: Callable):
|
||||||
|
"""注册消息处理函数"""
|
||||||
|
self.message_handlers.append(handler)
|
||||||
|
|
||||||
|
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""发送消息到指定端点"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
||||||
|
return await response.json()
|
||||||
|
except Exception as e:
|
||||||
|
# logger.error(f"发送消息失败: {str(e)}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run_sync(self):
|
||||||
|
"""同步方式运行服务器"""
|
||||||
|
uvicorn.run(self.app, host=self.host, port=self.port)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""异步方式运行服务器"""
|
||||||
|
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
||||||
|
self.server = uvicorn.Server(config)
|
||||||
|
|
||||||
|
await self.server.serve()
|
||||||
|
|
||||||
|
async def start_server(self):
|
||||||
|
"""启动服务器的异步方法"""
|
||||||
|
if not self._running:
|
||||||
|
self._running = True
|
||||||
|
await self.run()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""停止服务器"""
|
||||||
|
if hasattr(self, "server"):
|
||||||
|
self._running = False
|
||||||
|
# 正确关闭 uvicorn 服务器
|
||||||
|
self.server.should_exit = True
|
||||||
|
await self.server.shutdown()
|
||||||
|
# 等待服务器完全停止
|
||||||
|
if hasattr(self.server, "started") and self.server.started:
|
||||||
|
await self.server.main_loop()
|
||||||
|
# 清理处理程序
|
||||||
|
self.message_handlers.clear()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""启动服务器的便捷方法"""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(self.start_server())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
global_api = BaseMessageAPI(host=os.environ["HOST"], port=os.environ["PORT"])
|
||||||
@@ -103,6 +103,63 @@ class UserInfo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FormatInfo:
|
||||||
|
"""格式信息类"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
目前maimcore可接受的格式为text,image,emoji
|
||||||
|
可发送的格式为text,emoji,reply
|
||||||
|
"""
|
||||||
|
|
||||||
|
content_format: Optional[str] = None
|
||||||
|
accept_format: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict) -> "FormatInfo":
|
||||||
|
"""从字典创建FormatInfo实例
|
||||||
|
Args:
|
||||||
|
data: 包含必要字段的字典
|
||||||
|
Returns:
|
||||||
|
FormatInfo: 新的实例
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
content_format=data.get("content_format"),
|
||||||
|
accept_format=data.get("accept_format"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TemplateInfo:
|
||||||
|
"""模板信息类"""
|
||||||
|
|
||||||
|
template_items: Optional[List[Dict]] = None
|
||||||
|
template_name: Optional[str] = None
|
||||||
|
template_default: bool = True
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict) -> "TemplateInfo":
|
||||||
|
"""从字典创建TemplateInfo实例
|
||||||
|
Args:
|
||||||
|
data: 包含必要字段的字典
|
||||||
|
Returns:
|
||||||
|
TemplateInfo: 新的实例
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
template_items=data.get("template_items"),
|
||||||
|
template_name=data.get("template_name"),
|
||||||
|
template_default=data.get("template_default", True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseMessageInfo:
|
class BaseMessageInfo:
|
||||||
"""消息信息类"""
|
"""消息信息类"""
|
||||||
@@ -112,13 +169,15 @@ class BaseMessageInfo:
|
|||||||
time: Optional[int] = None
|
time: Optional[int] = None
|
||||||
group_info: Optional[GroupInfo] = None
|
group_info: Optional[GroupInfo] = None
|
||||||
user_info: Optional[UserInfo] = None
|
user_info: Optional[UserInfo] = None
|
||||||
|
format_info: Optional[FormatInfo] = None
|
||||||
|
template_info: Optional[TemplateInfo] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
def to_dict(self) -> Dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
result = {}
|
result = {}
|
||||||
for field, value in asdict(self).items():
|
for field, value in asdict(self).items():
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if isinstance(value, (GroupInfo, UserInfo)):
|
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
|
||||||
result[field] = value.to_dict()
|
result[field] = value.to_dict()
|
||||||
else:
|
else:
|
||||||
result[field] = value
|
result[field] = value
|
||||||
@@ -136,12 +195,16 @@ class BaseMessageInfo:
|
|||||||
"""
|
"""
|
||||||
group_info = GroupInfo.from_dict(data.get("group_info", {}))
|
group_info = GroupInfo.from_dict(data.get("group_info", {}))
|
||||||
user_info = UserInfo.from_dict(data.get("user_info", {}))
|
user_info = UserInfo.from_dict(data.get("user_info", {}))
|
||||||
|
format_info = FormatInfo.from_dict(data.get("format_info", {}))
|
||||||
|
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
|
||||||
return cls(
|
return cls(
|
||||||
platform=data.get("platform"),
|
platform=data.get("platform"),
|
||||||
message_id=data.get("message_id"),
|
message_id=data.get("message_id"),
|
||||||
time=data.get("time"),
|
time=data.get("time"),
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
user_info=user_info,
|
user_info=user_info,
|
||||||
|
format_info=format_info,
|
||||||
|
template_info=template_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
98
src/plugins/message/test.py
Normal file
98
src/plugins/message/test.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import unittest
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from api import BaseMessageAPI
|
||||||
|
from message_base import (
|
||||||
|
BaseMessageInfo,
|
||||||
|
UserInfo,
|
||||||
|
GroupInfo,
|
||||||
|
FormatInfo,
|
||||||
|
TemplateInfo,
|
||||||
|
MessageBase,
|
||||||
|
Seg,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
send_url = "http://localhost"
|
||||||
|
receive_port = 18002 # 接收消息的端口
|
||||||
|
send_port = 18000 # 发送消息的端口
|
||||||
|
test_endpoint = "/api/message"
|
||||||
|
|
||||||
|
# 创建并启动API实例
|
||||||
|
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
"""测试前的设置"""
|
||||||
|
self.received_messages = []
|
||||||
|
|
||||||
|
async def message_handler(message):
|
||||||
|
self.received_messages.append(message)
|
||||||
|
|
||||||
|
self.api = api
|
||||||
|
self.api.register_message_handler(message_handler)
|
||||||
|
self.server_task = asyncio.create_task(self.api.run())
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.skipTest("服务器启动超时")
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
"""测试后的清理"""
|
||||||
|
if hasattr(self, "server_task"):
|
||||||
|
await self.api.stop() # 先调用正常的停止流程
|
||||||
|
if not self.server_task.done():
|
||||||
|
self.server_task.cancel()
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self.server_task, timeout=100)
|
||||||
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def test_send_and_receive_message(self):
|
||||||
|
"""测试向运行中的API发送消息并接收响应"""
|
||||||
|
# 准备测试消息
|
||||||
|
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
|
||||||
|
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
|
||||||
|
format_info = FormatInfo(
|
||||||
|
content_format=["text"], accept_format=["text", "emoji", "reply"]
|
||||||
|
)
|
||||||
|
template_info = None
|
||||||
|
message_info = BaseMessageInfo(
|
||||||
|
platform="qq",
|
||||||
|
message_id=12345678,
|
||||||
|
time=12345678,
|
||||||
|
group_info=group_info,
|
||||||
|
user_info=user_info,
|
||||||
|
format_info=format_info,
|
||||||
|
template_info=template_info,
|
||||||
|
)
|
||||||
|
message = MessageBase(
|
||||||
|
message_info=message_info,
|
||||||
|
raw_message="测试消息",
|
||||||
|
message_segment=Seg(type="text", data="测试消息"),
|
||||||
|
)
|
||||||
|
test_message = message.to_dict()
|
||||||
|
|
||||||
|
# 发送测试消息到发送端口
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{send_url}:{send_port}{test_endpoint}",
|
||||||
|
json=test_message,
|
||||||
|
) as response:
|
||||||
|
response_data = await response.json()
|
||||||
|
self.assertEqual(response.status, 200)
|
||||||
|
self.assertEqual(response_data["status"], "success")
|
||||||
|
try:
|
||||||
|
async with asyncio.timeout(5): # 设置5秒超时
|
||||||
|
while len(self.received_messages) == 0:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
received_message = self.received_messages[0]
|
||||||
|
print(received_message)
|
||||||
|
self.received_messages.clear()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.fail("等待接收消息超时")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -6,15 +6,13 @@ from typing import Tuple, Union
|
|||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from nonebot import get_driver
|
|
||||||
import base64
|
import base64
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
logger = get_module_logger("model_utils")
|
logger = get_module_logger("model_utils")
|
||||||
|
|
||||||
@@ -34,8 +32,9 @@ class LLM_request:
|
|||||||
def __init__(self, model, **kwargs):
|
def __init__(self, model, **kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = getattr(config, model["key"])
|
self.api_key = os.environ[model["key"]]
|
||||||
self.base_url = getattr(config, model["base_url"])
|
self.base_url = os.environ[model["base_url"]]
|
||||||
|
print(self.api_key, self.base_url)
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
logger.error(f"原始 model dict 信息:{model}")
|
logger.error(f"原始 model dict 信息:{model}")
|
||||||
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import json
|
|||||||
import re
|
import re
|
||||||
from typing import Dict, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from nonebot import get_driver
|
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
|
|
||||||
@@ -14,9 +13,6 @@ from src.common.logger import get_module_logger
|
|||||||
|
|
||||||
logger = get_module_logger("scheduler")
|
logger = get_module_logger("scheduler")
|
||||||
|
|
||||||
driver = get_driver()
|
|
||||||
config = driver.config
|
|
||||||
|
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
enable_output: bool = True
|
enable_output: bool = True
|
||||||
@@ -183,5 +179,7 @@ class ScheduleGenerator:
|
|||||||
logger.info(f"时间[{time_str}]: 活动[{activity}]")
|
logger.info(f"时间[{time_str}]: 活动[{activity}]")
|
||||||
logger.info("==================")
|
logger.info("==================")
|
||||||
self.enable_output = False
|
self.enable_output = False
|
||||||
|
|
||||||
|
|
||||||
# 当作为组件导入时使用的实例
|
# 当作为组件导入时使用的实例
|
||||||
bot_schedule = ScheduleGenerator()
|
bot_schedule = ScheduleGenerator()
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
更新版本后,建议删除数据库messages中所有内容,不然会出现报错
|
|
||||||
该操作不会影响你的记忆
|
|
||||||
|
|
||||||
如果显示配置文件版本过低,运行根目录的bat
|
|
||||||
Reference in New Issue
Block a user