refactor: 初步重构为maimcore

This commit is contained in:
tcmofashi
2025-03-27 13:30:46 +08:00
parent 09c6500d79
commit 4c332d0b2f
26 changed files with 426 additions and 1213 deletions

61
bot.py
View File

@@ -4,15 +4,11 @@ import os
import shutil import shutil
import sys import sys
from pathlib import Path from pathlib import Path
import nonebot
import time import time
import uvicorn
from dotenv import load_dotenv
from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from dotenv import load_dotenv
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from src.main import MainSystem
logger = get_module_logger("main_bot") logger = get_module_logger("main_bot")
@@ -134,11 +130,7 @@ def scan_provider(env_config: dict):
async def graceful_shutdown(): async def graceful_shutdown():
try: try:
global uvicorn_server logger.info("正在优雅关闭麦麦...")
if uvicorn_server:
uvicorn_server.force_exit = True # 强制退出
await uvicorn_server.shutdown()
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks: for task in tasks:
task.cancel() task.cancel()
@@ -148,22 +140,6 @@ async def graceful_shutdown():
logger.error(f"麦麦关闭失败: {e}") logger.error(f"麦麦关闭失败: {e}")
async def uvicorn_main():
global uvicorn_server
config = uvicorn.Config(
app="__main__:app",
host=os.getenv("HOST", "127.0.0.1"),
port=int(os.getenv("PORT", 8080)),
reload=os.getenv("ENVIRONMENT") == "dev",
timeout_graceful_shutdown=5,
log_config=None,
access_log=False,
)
server = uvicorn.Server(config)
uvicorn_server = server
await server.serve()
def check_eula(): def check_eula():
eula_confirm_file = Path("eula.confirmed") eula_confirm_file = Path("eula.confirmed")
privacy_confirm_file = Path("privacy.confirmed") privacy_confirm_file = Path("privacy.confirmed")
@@ -245,7 +221,6 @@ def check_eula():
def raw_main(): def raw_main():
# 利用 TZ 环境变量设定程序工作的时区 # 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != "windows": if platform.system().lower() != "windows":
time.tzset() time.tzset()
@@ -256,40 +231,26 @@ def raw_main():
init_env() init_env()
load_env() load_env()
# load_logger()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config) scan_provider(env_config)
# 设置基础配置 # 返回MainSystem实例
base_config = { return MainSystem()
"websocket_port": int(env_config.get("PORT", 8080)),
"host": env_config.get("HOST", "127.0.0.1"),
"log_level": "INFO",
}
# 合并配置
nonebot.init(**base_config, **env_config)
# 注册适配器
global driver
driver = nonebot.get_driver()
driver.register_adapter(Adapter)
# 加载插件
nonebot.load_plugins("src/plugins")
if __name__ == "__main__": if __name__ == "__main__":
try: try:
raw_main() # 获取MainSystem实例
main_system = raw_main()
app = nonebot.get_asgi() # 创建事件循环
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
loop.run_until_complete(uvicorn_main()) # 执行初始化和任务调度
loop.run_until_complete(main_system.initialize())
loop.run_until_complete(main_system.schedule_tasks())
except KeyboardInterrupt: except KeyboardInterrupt:
logger.warning("收到中断信号,正在优雅关闭...") logger.warning("收到中断信号,正在优雅关闭...")
loop.run_until_complete(graceful_shutdown()) loop.run_until_complete(graceful_shutdown())

View File

@@ -1,21 +1,19 @@
import asyncio import asyncio
import time import time
from datetime import datetime from datetime import datetime
from .plugins.utils.statistic import LLMStatistics
from plugins.utils.statistic import LLMStatistics from .plugins.moods.moods import MoodManager
from plugins.moods.moods import MoodManager from .plugins.schedule.schedule_generator import bot_schedule
from plugins.schedule.schedule_generator import bot_schedule from .plugins.chat.emoji_manager import emoji_manager
from plugins.chat.emoji_manager import emoji_manager from .plugins.chat.relationship_manager import relationship_manager
from plugins.chat.relationship_manager import relationship_manager from .plugins.willing.willing_manager import willing_manager
from plugins.willing.willing_manager import willing_manager from .plugins.chat.chat_stream import chat_manager
from plugins.chat.chat_stream import chat_manager from .plugins.memory_system.memory import hippocampus
from plugins.memory_system.memory import hippocampus from .plugins.chat.message_sender import message_manager
from plugins.chat.message_sender import message_manager from .plugins.chat.storage import MessageStorage
from plugins.chat.storage import MessageStorage from .plugins.chat.config import global_config
from plugins.chat.config import global_config from .plugins.chat.bot import chat_bot
from common.logger import get_module_logger from .common.logger import get_module_logger
from fastapi import FastAPI
from plugins.chat.api import app as api_app
logger = get_module_logger("main") logger = get_module_logger("main")
@@ -25,13 +23,29 @@ class MainSystem:
self.llm_stats = LLMStatistics("llm_statistics.txt") self.llm_stats = LLMStatistics("llm_statistics.txt")
self.mood_manager = MoodManager.get_instance() self.mood_manager = MoodManager.get_instance()
self._message_manager_started = False self._message_manager_started = False
self.app = FastAPI()
self.app.mount("/chat", api_app) # 使用消息API替代直接的FastAPI实例
from .plugins.message import global_api
self.app = global_api
async def initialize(self): async def initialize(self):
"""初始化系统组件""" """初始化系统组件"""
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......") logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 启动API服务器改为异步启动
self.api_task = asyncio.create_task(self.app.run())
# 其他初始化任务
await asyncio.gather(
self._init_components(), # 将原有的初始化代码移到这个新方法中
# api_task,
)
logger.success("系统初始化完成")
async def _init_components(self):
"""初始化其他组件"""
# 启动LLM统计 # 启动LLM统计
self.llm_stats.start() self.llm_stats.start()
logger.success("LLM统计功能启动成功") logger.success("LLM统计功能启动成功")
@@ -64,10 +78,7 @@ class MainSystem:
bot_schedule.print_schedule() bot_schedule.print_schedule()
# 启动FastAPI服务器 # 启动FastAPI服务器
import uvicorn self.app.register_message_handler(chat_bot.message_process)
uvicorn.run(self.app, host="0.0.0.0", port=18000)
logger.success("API服务器启动成功")
async def schedule_tasks(self): async def schedule_tasks(self):
"""调度定时任务""" """调度定时任务"""
@@ -86,6 +97,7 @@ class MainSystem:
async def build_memory_task(self): async def build_memory_task(self):
"""记忆构建任务""" """记忆构建任务"""
while True: while True:
logger.info("正在进行记忆构建")
await hippocampus.operation_build_memory() await hippocampus.operation_build_memory()
await asyncio.sleep(global_config.build_memory_interval) await asyncio.sleep(global_config.build_memory_interval)
@@ -100,6 +112,7 @@ class MainSystem:
async def merge_memory_task(self): async def merge_memory_task(self):
"""记忆整合任务""" """记忆整合任务"""
while True: while True:
logger.info("正在进行记忆整合")
await asyncio.sleep(global_config.build_memory_interval + 10) await asyncio.sleep(global_config.build_memory_interval + 10)
async def print_mood_task(self): async def print_mood_task(self):
@@ -130,8 +143,9 @@ class MainSystem:
async def main(): async def main():
"""主函数""" """主函数"""
system = MainSystem() system = MainSystem()
await system.initialize() await asyncio.gather(system.initialize(), system.schedule_tasks(), system.api_task)
await system.schedule_tasks() # await system.initialize()
# await system.schedule_tasks()
if __name__ == "__main__": if __name__ == "__main__":

23
src/plugins/__init__.py Normal file
View File

@@ -0,0 +1,23 @@
"""
MaiMBot插件系统
包含聊天、情绪、记忆、日程等功能模块
"""
from .chat.chat_stream import chat_manager
from .chat.emoji_manager import emoji_manager
from .chat.relationship_manager import relationship_manager
from .moods.moods import MoodManager
from .willing.willing_manager import willing_manager
from .memory_system.memory import hippocampus
from .schedule.schedule_generator import bot_schedule
# 导出主要组件供外部使用
__all__ = [
"chat_manager",
"emoji_manager",
"relationship_manager",
"MoodManager",
"willing_manager",
"hippocampus",
"bot_schedule",
]

View File

@@ -1,154 +1,15 @@
import asyncio
import time
from nonebot import get_driver, on_message, on_notice, require
from nonebot.adapters.onebot.v11 import Bot, MessageEvent, NoticeEvent
from nonebot.typing import T_State
from ..moods.moods import MoodManager # 导入情绪管理器
from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics
from .bot import chat_bot
from .config import global_config
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from ..willing.willing_manager import willing_manager
from .chat_stream import chat_manager from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus from .message_sender import message_manager
from .message_sender import message_manager, message_sender
from .storage import MessageStorage from .storage import MessageStorage
from src.common.logger import get_module_logger from .config import global_config
logger = get_module_logger("chat_init") __all__ = [
"emoji_manager",
# 创建LLM统计实例 "relationship_manager",
llm_stats = LLMStatistics("llm_statistics.txt") "chat_manager",
"message_manager",
# 添加标志变量 "MessageStorage",
_message_manager_started = False "global_config",
]
# 获取驱动器
driver = get_driver()
config = driver.config
# 初始化表情管理器
emoji_manager.initialize()
logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
# 注册消息处理器
msg_in = on_message(priority=5)
# 注册和bot相关的通知处理器
notice_matcher = on_notice(priority=1)
# 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler
@driver.on_startup
async def start_background_tasks():
"""启动后台任务"""
# 启动LLM统计
llm_stats.start()
logger.success("LLM统计功能启动成功")
# 初始化并启动情绪管理器
mood_manager = MoodManager.get_instance()
mood_manager.start_mood_update(update_interval=global_config.mood_update_interval)
logger.success("情绪管理器启动成功")
# 只启动表情包管理任务
asyncio.create_task(emoji_manager.start_periodic_check(interval_MINS=global_config.EMOJI_CHECK_INTERVAL))
await bot_schedule.initialize()
bot_schedule.print_schedule()
@driver.on_startup
async def init_relationships():
"""在 NoneBot2 启动时初始化关系管理器"""
logger.debug("正在加载用户关系数据...")
await relationship_manager.load_all_relationships()
asyncio.create_task(relationship_manager._start_relationship_manager())
@driver.on_bot_connect
async def _(bot: Bot):
"""Bot连接成功时的处理"""
global _message_manager_started
logger.debug(f"-----------{global_config.BOT_NICKNAME}成功连接!-----------")
await willing_manager.ensure_started()
message_sender.set_bot(bot)
logger.success("-----------消息发送器已启动!-----------")
if not _message_manager_started:
asyncio.create_task(message_manager.start_processor())
_message_manager_started = True
logger.success("-----------消息处理器已启动!-----------")
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
logger.success("-----------开始偷表情包!-----------")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@msg_in.handle()
async def _(bot: Bot, event: MessageEvent, state: T_State):
# 处理合并转发消息
if "forward" in event.message:
await chat_bot.handle_forward_message(event, bot)
else:
await chat_bot.handle_message(event, bot)
@notice_matcher.handle()
async def _(bot: Bot, event: NoticeEvent, state: T_State):
logger.debug(f"收到通知:{event}")
await chat_bot.handle_notice(event, bot)
# 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建"""
await hippocampus.operation_build_memory()
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")
async def forget_memory_task():
"""每30秒执行一次记忆构建"""
print("\033[1;32m[记忆遗忘]\033[0m 开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=global_config.memory_forget_percentage)
print("\033[1;32m[记忆遗忘]\033[0m 记忆遗忘完成")
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval + 10, id="merge_memory")
async def merge_memory_task():
"""每30秒执行一次记忆构建"""
# print("\033[1;32m[记忆整合]\033[0m 开始整合")
# await hippocampus.operation_merge_memory(percentage=0.1)
# print("\033[1;32m[记忆整合]\033[0m 记忆整合完成")
@scheduler.scheduled_job("interval", seconds=30, id="print_mood")
async def print_mood_task():
"""每30秒打印一次情绪状态"""
mood_manager = MoodManager.get_instance()
mood_manager.print_mood_status()
@scheduler.scheduled_job("interval", seconds=7200, id="generate_schedule")
async def generate_schedule_task():
"""每2小时尝试生成一次日程"""
logger.debug("尝试生成日程")
await bot_schedule.initialize()
if not bot_schedule.enable_output:
bot_schedule.print_schedule()
@scheduler.scheduled_job("interval", seconds=3600, id="remove_recalled_message")
async def remove_recalled_message() -> None:
"""删除撤回消息"""
try:
storage = MessageStorage()
await storage.remove_recalled_message(time.time())
except Exception:
logger.exception("删除撤回消息失败")

View File

@@ -1,54 +0,0 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any
from .bot import chat_bot
from .message_cq import MessageRecvCQ
from .message_base import UserInfo, GroupInfo
from src.common.logger import get_module_logger
logger = get_module_logger("chat_api")
app = FastAPI()
class MessageRequest(BaseModel):
message_id: int
user_info: Dict[str, Any]
raw_message: str
group_info: Optional[Dict[str, Any]] = None
reply_message: Optional[Dict[str, Any]] = None
platform: str = "api"
@app.post("/api/message")
async def handle_message(message: MessageRequest):
try:
user_info = UserInfo(
user_id=message.user_info["user_id"],
user_nickname=message.user_info["user_nickname"],
user_cardname=message.user_info.get("user_cardname"),
platform=message.platform,
)
group_info = None
if message.group_info:
group_info = GroupInfo(
group_id=message.group_info["group_id"],
group_name=message.group_info.get("group_name"),
platform=message.platform,
)
message_cq = MessageRecvCQ(
message_id=message.message_id,
user_info=user_info,
raw_message=message.raw_message,
group_info=group_info,
reply_message=message.reply_message,
platform=message.platform,
)
await chat_bot.message_process(message_cq)
return {"status": "success"}
except Exception as e:
logger.exception("API处理消息时出错")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@@ -1,16 +1,7 @@
import re import re
import time import time
from random import random from random import random
from nonebot.adapters.onebot.v11 import ( import json
Bot,
MessageEvent,
PrivateMessageEvent,
GroupMessageEvent,
NoticeEvent,
PokeNotifyEvent,
GroupRecallNoticeEvent,
FriendRecallNoticeEvent,
)
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
@@ -18,9 +9,7 @@ from .config import global_config
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
from .message_cq import (
MessageRecvCQ,
)
from .chat_stream import chat_manager from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
@@ -30,7 +19,7 @@ from .utils import is_mentioned_bot_in_message
from .utils_image import image_path_to_base64 from .utils_image import image_path_to_base64
from .utils_user import get_user_nickname, get_user_cardname from .utils_user import get_user_nickname, get_user_cardname
from ..willing.willing_manager import willing_manager # 导入意愿管理器 from ..willing.willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from ..message import UserInfo, GroupInfo, Seg
from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig from src.common.logger import get_module_logger, CHAT_STYLE_CONFIG, LogConfig
@@ -62,7 +51,7 @@ class ChatBot:
if not self._started: if not self._started:
self._started = True self._started = True
async def message_process(self, message_cq: MessageRecvCQ) -> None: async def message_process(self, message_data: str) -> None:
"""处理转化后的统一格式消息 """处理转化后的统一格式消息
1. 过滤消息 1. 过滤消息
2. 记忆激活 2. 记忆激活
@@ -71,12 +60,11 @@ class ChatBot:
5. 更新关系 5. 更新关系
6. 更新情绪 6. 更新情绪
""" """
await message_cq.initialize() # message_json = json.loads(message_data)
message_json = message_cq.to_dict()
# 哦我嘞个json # 哦我嘞个json
# 进入maimbot # 进入maimbot
message = MessageRecv(message_json) message = MessageRecv(message_data)
groupinfo = message.message_info.group_info groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info userinfo = message.message_info.user_info
messageinfo = message.message_info messageinfo = message.message_info
@@ -146,7 +134,7 @@ class ChatBot:
response = None response = None
# 开始组织语言 # 开始组织语言
if random() < reply_probability: if random() < reply_probability + 100:
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.BOT_QQ, user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME, user_nickname=global_config.BOT_NICKNAME,
@@ -278,235 +266,6 @@ class ChatBot:
# chat_stream=chat # chat_stream=chat
# ) # )
async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None:
"""处理收到的通知"""
if isinstance(event, PokeNotifyEvent):
# 戳一戳 通知
# 不处理其他人的戳戳
if not event.is_tome():
return
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
# 白名单模式
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
raw_message = f"[戳了戳]{global_config.BOT_NICKNAME}" # 默认类型
if info := event.model_extra["raw_info"]:
poke_type = info[2].get("txt", "戳了戳") # 戳戳类型,例如“拍一拍”、“揉一揉”、“捏一捏”
custom_poke_message = info[4].get("txt", "") # 自定义戳戳消息,若不存在会为空字符串
raw_message = f"[{poke_type}]{global_config.BOT_NICKNAME}{custom_poke_message}"
raw_message += "(这是一个类似摸摸头的友善行为,而不是恶意行为,请不要作出攻击发言)"
user_info = UserInfo(
user_id=event.user_id,
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
user_cardname=None,
platform="qq",
)
if event.group_id:
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
else:
group_info = None
message_cq = MessageRecvCQ(
message_id=0,
user_info=user_info,
raw_message=str(raw_message),
group_info=group_info,
reply_message=None,
platform="qq",
)
await self.message_process(message_cq)
elif isinstance(event, GroupRecallNoticeEvent) or isinstance(event, FriendRecallNoticeEvent):
user_info = UserInfo(
user_id=event.user_id,
user_nickname=get_user_nickname(event.user_id) or None,
user_cardname=get_user_cardname(event.user_id) or None,
platform="qq",
)
if isinstance(event, GroupRecallNoticeEvent):
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
else:
group_info = None
chat = await chat_manager.get_or_create_stream(
platform=user_info.platform, user_info=user_info, group_info=group_info
)
await self.storage.store_recalled_message(event.message_id, time.time(), chat)
async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
"""处理收到的消息"""
self.bot = bot # 更新 bot 实例
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
if (
event.reply
and hasattr(event.reply, "sender")
and hasattr(event.reply.sender, "user_id")
and event.reply.sender.user_id in global_config.ban_user_id
):
logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息")
return
# 处理私聊消息
if isinstance(event, PrivateMessageEvent):
if not global_config.enable_friend_chat: # 私聊过滤
return
else:
try:
user_info = UserInfo(
user_id=event.user_id,
user_nickname=(await bot.get_stranger_info(user_id=event.user_id, no_cache=True))["nickname"],
user_cardname=None,
platform="qq",
)
except Exception as e:
logger.error(f"获取陌生人信息失败: {e}")
return
logger.debug(user_info)
# group_info = GroupInfo(group_id=0, group_name="私聊", platform="qq")
group_info = None
# 处理群聊消息
else:
# 白名单设定由nontbot侧完成
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
user_info = UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform="qq",
)
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
message_cq = MessageRecvCQ(
message_id=event.message_id,
user_info=user_info,
raw_message=str(event.original_message),
group_info=group_info,
reply_message=event.reply,
platform="qq",
)
await self.message_process(message_cq)
async def handle_forward_message(self, event: MessageEvent, bot: Bot) -> None:
"""专用于处理合并转发的消息处理器"""
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
if isinstance(event, GroupMessageEvent):
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
# 获取合并转发消息的详细信息
forward_info = await bot.get_forward_msg(message_id=event.message_id)
messages = forward_info["messages"]
# 构建合并转发消息的文本表示
processed_messages = []
for node in messages:
# 提取发送者昵称
nickname = node["sender"].get("nickname", "未知用户")
# 递归处理消息内容
message_content = await self.process_message_segments(node["message"], layer=0)
# 拼接为【昵称】+ 内容
processed_messages.append(f"{nickname}{message_content}")
# 组合所有消息
combined_message = "\n".join(processed_messages)
combined_message = f"合并转发消息内容:\n{combined_message}"
# 构建用户信息(使用转发消息的发送者)
user_info = UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card if hasattr(event.sender, "card") else None,
platform="qq",
)
# 构建群聊信息(如果是群聊)
group_info = None
if isinstance(event, GroupMessageEvent):
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
# 创建消息对象
message_cq = MessageRecvCQ(
message_id=event.message_id,
user_info=user_info,
raw_message=combined_message,
group_info=group_info,
reply_message=event.reply,
platform="qq",
)
# 进入标准消息处理流程
await self.message_process(message_cq)
async def process_message_segments(self, segments: list, layer: int) -> str:
"""递归处理消息段"""
parts = []
for seg in segments:
part = await self.process_segment(seg, layer + 1)
parts.append(part)
return "".join(parts)
async def process_segment(self, seg: dict, layer: int) -> str:
"""处理单个消息段"""
seg_type = seg["type"]
if layer > 3:
# 防止有那种100层转发消息炸飞麦麦
return "【转发消息】"
if seg_type == "text":
return seg["data"]["text"]
elif seg_type == "image":
return "[图片]"
elif seg_type == "face":
return "[表情]"
elif seg_type == "at":
return f"@{seg['data'].get('qq', '未知用户')}"
elif seg_type == "forward":
# 递归处理嵌套的合并转发消息
nested_nodes = seg["data"].get("content", [])
nested_messages = []
nested_messages.append("合并转发消息内容:")
for node in nested_nodes:
nickname = node["sender"].get("nickname", "未知用户")
content = await self.process_message_segments(node["message"], layer=layer)
# nested_messages.append('-' * layer)
nested_messages.append(f"{'--' * layer}{nickname}{content}")
# nested_messages.append(f"{'--' * layer}合并转发第【{layer}】层结束")
return "\n".join(nested_messages)
else:
return f"[{seg_type}]"
# 创建全局ChatBot实例 # 创建全局ChatBot实例
chat_bot = ChatBot() chat_bot = ChatBot()

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from ...common.database import db from ...common.database import db
from .message_base import GroupInfo, UserInfo from ..message.message_base import GroupInfo, UserInfo
from src.common.logger import get_module_logger from src.common.logger import get_module_logger

View File

@@ -1,385 +0,0 @@
import base64
import html
import asyncio
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import ssl
import os
import aiohttp
from src.common.logger import get_module_logger
from nonebot import get_driver
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .message_base import Seg
from .utils_user import get_user_nickname, get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver()
config = driver.config
# 创建SSL上下文
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("AES128-GCM-SHA256")
logger = get_module_logger("cq_code")
@dataclass
class CQCode:
"""
CQ码数据类用于存储和处理CQ码
属性:
type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
def __post_init__(self):
"""初始化LLM实例"""
pass
async def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == "text":
self.translated_segments = Seg(type="text", data=self.params.get("text", ""))
elif self.type == "image":
base64_data = await self.translate_image()
if base64_data:
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
else:
self.translated_segments = Seg(type="emoji", data=base64_data)
else:
self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == "at":
if self.params.get("qq") == "all":
self.translated_segments = Seg(type="text", data="@[全体成员]")
else:
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(type="text", data=f"[@{user_nickname or '某人'}]")
elif self.type == "reply":
reply_segments = await self.translate_reply()
if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments)
else:
self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]")
elif self.type == "forward":
forward_segments = await self.translate_forward()
if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
self.translated_segments = Seg(type="text", data="[转发消息]")
else:
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
async def get_img(self) -> Optional[str]:
"""异步获取图片并转换为base64"""
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/50.0.2661.87 Safari/537.36",
"Accept": "text/html, application/xhtml xml, */*",
"Accept-Encoding": "gbk, GB2312",
"Accept-Language": "zh-cn",
"Content-Type": "application/x-www-form-urlencoded",
"Cache-Control": "no-cache",
}
url = html.unescape(self.params["url"])
if not url.startswith(("http://", "https://")):
return None
max_retries = 3
for retry in range(max_retries):
try:
logger.debug(f"获取图片中: {url}")
# 设置SSL上下文和创建连接器
conn = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=conn) as session:
async with session.get(
url,
headers=headers,
timeout=aiohttp.ClientTimeout(total=15),
allow_redirects=True,
) as response:
# 腾讯服务器特殊状态码处理
if response.status == 400 and "multimedia.nt.qq.com.cn" in url:
return None
if response.status != 200:
raise aiohttp.ClientError(f"HTTP {response.status}")
# 验证内容类型
content_type = response.headers.get("Content-Type", "")
if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}")
# 读取响应内容
content = await response.read()
logger.debug(f"获取图片成功: {url}")
# 转换为Base64
image_base64 = base64.b64encode(content).decode("utf-8")
self.image_base64 = image_base64
return image_base64
except (aiohttp.ClientError, ValueError) as e:
if retry == max_retries - 1:
logger.error(f"最终请求失败: {str(e)}")
await asyncio.sleep(1.5**retry) # 指数退避
except Exception as e:
logger.exception(f"获取图片时发生未知错误: {str(e)}")
return None
return None
async def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if "url" not in self.params:
return None
return await self.get_img()
async def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
if "content" not in self.params:
return None
content = self.unescape(self.params["content"])
import ast
try:
messages = ast.literal_eval(content)
except ValueError as e:
logger.error(f"解析转发消息内容失败: {str(e)}")
return None
formatted_segments = []
for msg in messages:
sender = msg.get("sender", {})
nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get("raw_message", "")
message_array = msg.get("message", [])
if message_array and isinstance(message_array, list):
for message_part in message_array:
if message_part.get("type") == "forward":
content_seg = Seg(type="text", data="[转发消息]")
break
else:
if raw_message:
from .message_cq import MessageRecvCQ
user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info = GroupInfo(
platform="qq",
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)),
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_info=group_info,
)
await message_obj.initialize()
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
else:
content_seg = Seg(type="text", data="[空消息]")
else:
if raw_message:
from .message_cq import MessageRecvCQ
user_info = UserInfo(
platform="qq",
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info = GroupInfo(
platform="qq",
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0)),
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message,
plain_text=raw_message,
group_info=group_info,
)
await message_obj.initialize()
content_seg = Seg(type="seglist", data=[message_obj.message_segment])
else:
content_seg = Seg(type="text", data="[空消息]")
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_segments.append(content_seg)
formatted_segments.append(Seg(type="text", data="\n"))
return formatted_segments
except Exception as e:
logger.error(f"处理转发消息失败: {str(e)}")
return None
async def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
if self.reply_message is None:
return None
if hasattr(self.reply_message, "group_id"):
group_info = GroupInfo(platform="qq", group_id=self.reply_message.group_id, group_name="")
else:
group_info = None
if self.reply_message.sender.user_id:
message_obj = MessageRecvCQ(
user_info=UserInfo(
user_id=self.reply_message.sender.user_id, user_nickname=self.reply_message.sender.nickname
),
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_info=group_info,
)
await message_obj.initialize()
segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(Seg(type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
else:
segments.append(
Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type="seglist", data=[message_obj.message_segment]))
segments.append(Seg(type="text", data="]"))
return segments
else:
return None
@staticmethod
def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符"""
return text.replace("&#44;", ",").replace("&#91;", "[").replace("&#93;", "]").replace("&amp;", "&")
class CQCode_tool:
@staticmethod
def cq_from_dict_to_class(cq_code: Dict, msg, reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
Args:
cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典(可选)
Returns:
CQCode对象
"""
# 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get("type", "text")
params = {}
if cq_type == "text":
params["text"] = cq_code.get("data", {}).get("text", "")
else:
params = cq_code.get("data", {})
instance = CQCode(
type=cq_type,
params=params,
group_info=msg.message_info.group_info,
user_info=msg.message_info.user_info,
reply_message=reply,
)
return instance
@staticmethod
def create_reply_cq(message_id: int) -> str:
"""
创建回复CQ码
Args:
message_id: 回复的消息ID
Returns:
回复CQ码字符串
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;").replace(",", "&#44;")
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;").replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;").replace("[", "&#91;").replace("]", "&#93;").replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool()

View File

@@ -9,8 +9,6 @@ from typing import Optional, Tuple
from PIL import Image from PIL import Image
import io import io
from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
@@ -21,8 +19,6 @@ from src.common.logger import get_module_logger
logger = get_module_logger("emoji") logger = get_module_logger("emoji")
driver = get_driver()
config = driver.config
image_manager = ImageManager() image_manager = ImageManager()
@@ -118,9 +114,11 @@ class EmojiManager:
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = [e for e in all_emojis = [
db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1}) e
if 'blacklist' not in e] for e in db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1})
if "blacklist" not in e
]
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")

View File

@@ -2,7 +2,6 @@ import random
import time import time
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@@ -21,9 +20,6 @@ llm_config = LogConfig(
logger = get_module_logger("llm_generator", config=llm_config) logger = get_module_logger("llm_generator", config=llm_config)
driver = get_driver()
config = driver.config
class ResponseGenerator: class ResponseGenerator:
def __init__(self): def __init__(self):

View File

@@ -9,7 +9,7 @@ import urllib3
from .utils_image import image_manager from .utils_image import image_manager
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase from ..message.message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream from .chat_stream import ChatStream
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@@ -75,19 +75,6 @@ class MessageRecv(Message):
""" """
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = message_dict.get("message_segment", {})
if message_segment.get("data", "") == "[json]":
# 提取json消息中的展示信息
pattern = r"\[CQ:json,data=(?P<json_data>.+?)\]"
match = re.search(pattern, message_dict.get("raw_message", ""))
raw_json = html.unescape(match.group("json_data"))
try:
json_message = json.loads(raw_json)
except json.JSONDecodeError:
json_message = {}
message_segment["data"] = json_message.get("prompt", "")
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message") self.raw_message = message_dict.get("raw_message")

View File

@@ -1,170 +0,0 @@
import time
from dataclasses import dataclass
from typing import Dict, Optional
import urllib3
from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq"
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(message_info=message_info, message_segment=None, raw_message=None)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
raw_message: str,
group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None,
):
# 调用父类初始化
super().__init__(message_id, user_info, group_info, platform)
# 私聊消息不携带group_info
if group_info is None:
pass
elif group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = None # 初始化为None
self.raw_message = raw_message
# 异步初始化在外部完成
# 添加对reply的解析
self.reply_message = reply_message
async def initialize(self):
"""异步初始化方法"""
self.message_segment = await self._parse_message(self.raw_message, self.reply_message)
async def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""异步解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find("[CQ:", start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find("]", cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start : cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
cq_code_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
await cq_code_obj.translate() # 异步调用translate
if cq_code_obj.translated_segments:
segments.append(cq_code_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type="seglist", data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(self, data: Dict):
# 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg.from_dict(data.get("message_segment", {}))
super().__init__(
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform,
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 处理消息段
if self.message_segment.type == "seglist":
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return "".join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == "text":
return str(seg.data)
elif seg.type == "image":
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == "emoji":
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == "at":
return f"[CQ:at,qq={seg.data}]"
elif seg.type == "reply":
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -3,9 +3,8 @@ import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from nonebot.adapters.onebot.v11 import Bot
from ...common.database import db from ...common.database import db
from .message_cq import MessageSendCQ from ..message.api import global_api
from .message import MessageSending, MessageThinking, MessageSet from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage from .storage import MessageStorage
@@ -32,9 +31,9 @@ class Message_Sender:
self.last_send_time = 0 self.last_send_time = 0
self._current_bot = None self._current_bot = None
def set_bot(self, bot: Bot): def set_bot(self, bot):
"""设置当前bot实例""" """设置当前bot实例"""
self._current_bot = bot pass
def get_recalled_messages(self, stream_id: str) -> list: def get_recalled_messages(self, stream_id: str) -> list:
"""获取所有撤回的消息""" """获取所有撤回的消息"""
@@ -60,31 +59,14 @@ class Message_Sender:
break break
if not is_recalled: if not is_recalled:
message_json = message.to_dict() message_json = message.to_dict()
message_send = MessageSendCQ(data=message_json)
message_preview = truncate_message(message.processed_plain_text) message_preview = truncate_message(message.processed_plain_text)
if message_send.message_info.group_info and message_send.message_info.group_info.group_id: try:
try: result = await global_api.send_message("http://127.0.0.1:18002/api/message", message_json)
await self._current_bot.send_group_msg( if result["status"] == "success":
group_id=message.message_info.group_info.group_id,
message=message_send.raw_message,
auto_escape=False,
)
logger.success(f"发送消息“{message_preview}”成功") logger.success(f"发送消息“{message_preview}”成功")
except Exception as e: except Exception as e:
logger.error(f"[调试] 发生错误 {e}") logger.error(f"发送消息“{message_preview}”失败: {str(e)}")
logger.error(f"[调试] 发送消息“{message_preview}”失败")
else:
try:
logger.debug(message.message_info.user_info)
await self._current_bot.send_private_msg(
user_id=message.sender_info.user_id,
message=message_send.raw_message,
auto_escape=False,
)
logger.success(f"发送消息“{message_preview}”成功")
except Exception as e:
logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息“{message_preview}”失败")
class MessageContainer: class MessageContainer:

View File

@@ -3,7 +3,7 @@ from typing import Optional
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ...common.database import db from ...common.database import db
from .message_base import UserInfo from ..message.message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
import math import math
from bson.decimal128 import Decimal128 from bson.decimal128 import Decimal128
@@ -122,11 +122,15 @@ class RelationshipManager:
relationship.relationship_value = float(relationship.relationship_value.to_decimal()) relationship.relationship_value = float(relationship.relationship_value.to_decimal())
else: else:
relationship.relationship_value = float(relationship.relationship_value) relationship.relationship_value = float(relationship.relationship_value)
logger.info(f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}") logger.info(
f"[关系管理] 用户 {user_id}({platform}) 的关系值已转换为double类型: {relationship.relationship_value}"
)
except (ValueError, TypeError): except (ValueError, TypeError):
# 如果不能解析/强转则将relationship.relationship_value设置为double类型的0 # 如果不能解析/强转则将relationship.relationship_value设置为double类型的0
relationship.relationship_value = 0.0 relationship.relationship_value = 0.0
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型已设置为0") logger.warning(
f"[关系管理] 用户 {user_id}({platform}) 的关系值无法转换为double类型已设置为0"
)
relationship.relationship_value += value relationship.relationship_value += value
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
relationship.saved = True relationship.saved = True

View File

@@ -1,6 +1,5 @@
from typing import List, Optional from typing import List, Optional
from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
@@ -15,9 +14,6 @@ topic_config = LogConfig(
logger = get_module_logger("topic_identifier", config=topic_config) logger = get_module_logger("topic_identifier", config=topic_config)
driver = get_driver()
config = driver.config
class TopicIdentifier: class TopicIdentifier:
def __init__(self): def __init__(self):

View File

@@ -7,20 +7,17 @@ from typing import Dict, List
import jieba import jieba
import numpy as np import numpy as np
from nonebot import get_driver
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config from .config import global_config
from .message import MessageRecv, Message from .message import MessageRecv, Message
from .message_base import UserInfo from ..message.message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
from ...common.database import db from ...common.database import db
driver = get_driver()
config = driver.config
logger = get_module_logger("chat_utils") logger = get_module_logger("chat_utils")
@@ -291,7 +288,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
for sentence in sentences: for sentence in sentences:
parts = sentence.split("") parts = sentence.split("")
current_sentence = parts[0] current_sentence = parts[0]
if not is_western_paragraph(current_sentence): if not is_western_paragraph(current_sentence):
for part in parts[1:]: for part in parts[1:]:
if random.random() < split_strength: if random.random() < split_strength:
new_sentences.append(current_sentence.strip()) new_sentences.append(current_sentence.strip())
@@ -323,7 +320,7 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
for sentence in sentences: for sentence in sentences:
sentence = sentence.rstrip(",") sentence = sentence.rstrip(",")
# 西文字符句子不进行随机合并 # 西文字符句子不进行随机合并
if not is_western_paragraph(current_sentence): if not is_western_paragraph(current_sentence):
if random.random() < split_strength * 0.5: if random.random() < split_strength * 0.5:
sentence = sentence.replace("", "").replace(",", "") sentence = sentence.replace("", "").replace(",", "")
elif random.random() < split_strength: elif random.random() < split_strength:
@@ -364,10 +361,10 @@ def random_remove_punctuation(text: str) -> str:
def process_llm_response(text: str) -> List[str]: def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content) # processed_response = process_text_with_typos(content)
# 对西文字符段落的回复长度设置为汉字字符的两倍 # 对西文字符段落的回复长度设置为汉字字符的两倍
if len(text) > 100 and not is_western_paragraph(text) : if len(text) > 100 and not is_western_paragraph(text):
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复") logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ["懒得说"] return ["懒得说"]
elif len(text) > 200 : elif len(text) > 200:
logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复") logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ["懒得说"] return ["懒得说"]
# 处理长消息 # 处理长消息
@@ -530,12 +527,12 @@ def recover_kaomoji(sentences, placeholder_to_kaomoji):
recovered_sentences.append(sentence) recovered_sentences.append(sentence)
return recovered_sentences return recovered_sentences
def is_western_char(char): def is_western_char(char):
"""检测是否为西文字符""" """检测是否为西文字符"""
return len(char.encode('utf-8')) <= 2 return len(char.encode("utf-8")) <= 2
def is_western_paragraph(paragraph): def is_western_paragraph(paragraph):
"""检测是否为西文字符段落""" """检测是否为西文字符段落"""
return all(is_western_char(char) for char in paragraph if char.isalnum()) return all(is_western_char(char) for char in paragraph if char.isalnum())

View File

@@ -6,7 +6,6 @@ from typing import Optional
from PIL import Image from PIL import Image
import io import io
from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
@@ -16,9 +15,6 @@ from src.common.logger import get_module_logger
logger = get_module_logger("chat_image") logger = get_module_logger("chat_image")
driver = get_driver()
config = driver.config
class ImageManager: class ImageManager:
_instance = None _instance = None

View File

@@ -1,11 +1 @@
from nonebot import get_app
from .api import router
from src.common.logger import get_module_logger
# 获取主应用实例并挂载路由
app = get_app()
app.include_router(router, prefix="/api")
# 打印日志方便确认API已注册
logger = get_module_logger("cfg_reload")
logger.success("配置重载API已注册可通过 /api/reload-config 访问")

View File

@@ -8,7 +8,6 @@ import re
import jieba import jieba
import networkx as nx import networkx as nx
from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import ( from ..chat.utils import (
@@ -232,13 +231,13 @@ class Hippocampus:
# 创建双峰分布的记忆调度器 # 创建双峰分布的记忆调度器
scheduler = MemoryBuildScheduler( scheduler = MemoryBuildScheduler(
n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值4小时前 n_hours1=global_config.memory_build_distribution[0], # 第一个分布均值4小时前
std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差 std_hours1=global_config.memory_build_distribution[1], # 第一个分布标准差
weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60% weight1=global_config.memory_build_distribution[2], # 第一个分布权重 60%
n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值24小时前 n_hours2=global_config.memory_build_distribution[3], # 第二个分布均值24小时前
std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差 std_hours2=global_config.memory_build_distribution[4], # 第二个分布标准差
weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40% weight2=global_config.memory_build_distribution[5], # 第二个分布权重 40%
total_samples=global_config.build_memory_sample_num # 总共生成10个时间点 total_samples=global_config.build_memory_sample_num, # 总共生成10个时间点
) )
# 生成时间戳数组 # 生成时间戳数组
@@ -250,9 +249,7 @@ class Hippocampus:
chat_samples = [] chat_samples = []
for timestamp in timestamps: for timestamp in timestamps:
messages = self.random_get_msg_snippet( messages = self.random_get_msg_snippet(
timestamp, timestamp, global_config.build_memory_sample_length, max_memorized_time_per_msg
global_config.build_memory_sample_length,
max_memorized_time_per_msg
) )
if messages: if messages:
time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600
@@ -297,16 +294,16 @@ class Hippocampus:
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
# 使用正则表达式提取<>中的内容 # 使用正则表达式提取<>中的内容
topics = re.findall(r'<([^>]+)>', topics_response[0]) topics = re.findall(r"<([^>]+)>", topics_response[0])
# 如果没有找到<>包裹的内容,返回['none'] # 如果没有找到<>包裹的内容,返回['none']
if not topics: if not topics:
topics = ['none'] topics = ["none"]
else: else:
# 处理提取出的话题 # 处理提取出的话题
topics = [ topics = [
topic.strip() topic.strip()
for topic in ','.join(topics).replace("", ",").replace("", ",").replace(" ", ",").split(",") for topic in ",".join(topics).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip() if topic.strip()
] ]
@@ -314,8 +311,7 @@ class Hippocampus:
# any()检查topic中是否包含任何一个filter_keywords中的关键词 # any()检查topic中是否包含任何一个filter_keywords中的关键词
# 只保留不包含禁用关键词的topic # 只保留不包含禁用关键词的topic
filtered_topics = [ filtered_topics = [
topic for topic in topics topic for topic in topics if not any(keyword in topic for keyword in global_config.memory_ban_words)
if not any(keyword in topic for keyword in global_config.memory_ban_words)
] ]
logger.debug(f"过滤后话题: {filtered_topics}") logger.debug(f"过滤后话题: {filtered_topics}")
@@ -331,14 +327,14 @@ class Hippocampus:
# 初始化压缩后的记忆集合和相似主题字典 # 初始化压缩后的记忆集合和相似主题字典
compressed_memory = set() # 存储压缩后的(主题,内容)元组 compressed_memory = set() # 存储压缩后的(主题,内容)元组
similar_topics_dict = {} # 存储每个话题的相似主题列表 similar_topics_dict = {} # 存储每个话题的相似主题列表
# 遍历每个主题及其对应的LLM任务 # 遍历每个主题及其对应的LLM任务
for topic, task in tasks: for topic, task in tasks:
response = await task response = await task
if response: if response:
# 将主题和LLM生成的内容添加到压缩记忆中 # 将主题和LLM生成的内容添加到压缩记忆中
compressed_memory.add((topic, response[0])) compressed_memory.add((topic, response[0]))
# 为当前主题寻找相似的已存在主题 # 为当前主题寻找相似的已存在主题
existing_topics = list(self.memory_graph.G.nodes()) existing_topics = list(self.memory_graph.G.nodes())
similar_topics = [] similar_topics = []
@@ -404,7 +400,7 @@ class Hippocampus:
logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}")
all_added_nodes.extend(topic for topic, _ in compressed_memory) all_added_nodes.extend(topic for topic, _ in compressed_memory)
# all_connected_nodes.extend(topic for topic, _ in similar_topics_dict) # all_connected_nodes.extend(topic for topic, _ in similar_topics_dict)
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) all_topics.append(topic)
@@ -415,13 +411,13 @@ class Hippocampus:
for similar_topic, similarity in similar_topics: for similar_topic, similarity in similar_topics:
if topic != similar_topic: if topic != similar_topic:
strength = int(similarity * 10) strength = int(similarity * 10)
logger.debug(f"连接相似节点: {topic}{similar_topic} (强度: {strength})") logger.debug(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
all_added_edges.append(f"{topic}-{similar_topic}") all_added_edges.append(f"{topic}-{similar_topic}")
all_connected_nodes.append(topic) all_connected_nodes.append(topic)
all_connected_nodes.append(similar_topic) all_connected_nodes.append(similar_topic)
self.memory_graph.G.add_edge( self.memory_graph.G.add_edge(
topic, topic,
similar_topic, similar_topic,
@@ -442,11 +438,10 @@ class Hippocampus:
logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}")
# logger.success(f"强化连接: {', '.join(all_added_edges)}") # logger.success(f"强化连接: {', '.join(all_added_edges)}")
self.sync_memory_to_db() self.sync_memory_to_db()
end_time = time.time() end_time = time.time()
logger.success( logger.success(
f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} " f"--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} 秒--------------------------"
"秒--------------------------"
) )
def sync_memory_to_db(self): def sync_memory_to_db(self):
@@ -800,16 +795,16 @@ class Hippocampus:
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 4)) topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 4))
# 使用正则表达式提取<>中的内容 # 使用正则表达式提取<>中的内容
print(f"话题: {topics_response[0]}") print(f"话题: {topics_response[0]}")
topics = re.findall(r'<([^>]+)>', topics_response[0]) topics = re.findall(r"<([^>]+)>", topics_response[0])
# 如果没有找到<>包裹的内容,返回['none'] # 如果没有找到<>包裹的内容,返回['none']
if not topics: if not topics:
topics = ['none'] topics = ["none"]
else: else:
# 处理提取出的话题 # 处理提取出的话题
topics = [ topics = [
topic.strip() topic.strip()
for topic in ','.join(topics).replace("", ",").replace("", ",").replace(" ", ",").split(",") for topic in ",".join(topics).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if topic.strip() if topic.strip()
] ]
@@ -885,7 +880,7 @@ class Hippocampus:
# 识别主题 # 识别主题
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
print(f"识别主题: {identified_topics}") print(f"识别主题: {identified_topics}")
if identified_topics[0] == "none": if identified_topics[0] == "none":
return 0 return 0
@@ -946,7 +941,7 @@ class Hippocampus:
# 计算最终激活值 # 计算最终激活值
activation = int((topic_match + average_similarities) / 2 * 100) activation = int((topic_match + average_similarities) / 2 * 100)
logger.info(f"识别<{text[:15]}...>主题: {identified_topics}, 匹配率: {topic_match:.3f}, 激活值: {activation}") logger.info(f"识别<{text[:15]}...>主题: {identified_topics}, 匹配率: {topic_match:.3f}, 激活值: {activation}")
return activation return activation
@@ -994,9 +989,6 @@ def segment_text(text):
return seg_text return seg_text
driver = get_driver()
config = driver.config
start_time = time.time() start_time = time.time()
# 创建记忆图 # 创建记忆图

View File

@@ -0,0 +1,26 @@
"""Maim Message - A message handling library"""
__version__ = "0.1.0"
from .api import BaseMessageAPI, global_api
from .message_base import (
Seg,
GroupInfo,
UserInfo,
FormatInfo,
TemplateInfo,
BaseMessageInfo,
MessageBase,
)
__all__ = [
"BaseMessageAPI",
"Seg",
"global_api",
"GroupInfo",
"UserInfo",
"FormatInfo",
"TemplateInfo",
"BaseMessageInfo",
"MessageBase",
]

View File

@@ -0,0 +1,86 @@
from fastapi import FastAPI, HTTPException
from typing import Optional, Dict, Any, Callable, List
import aiohttp
import asyncio
import uvicorn
import os
class BaseMessageAPI:
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
# try:
for handler in self.message_handlers:
await handler(message)
return {"status": "success"}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e)) from e
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception as e:
# logger.error(f"发送消息失败: {str(e)}")
pass
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = BaseMessageAPI(host=os.environ["HOST"], port=os.environ["PORT"])

View File

@@ -103,6 +103,63 @@ class UserInfo:
) )
@dataclass
class FormatInfo:
"""格式信息类"""
"""
目前maimcore可接受的格式为text,image,emoji
可发送的格式为text,emoji,reply
"""
content_format: Optional[str] = None
accept_format: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "FormatInfo":
"""从字典创建FormatInfo实例
Args:
data: 包含必要字段的字典
Returns:
FormatInfo: 新的实例
"""
return cls(
content_format=data.get("content_format"),
accept_format=data.get("accept_format"),
)
@dataclass
class TemplateInfo:
"""模板信息类"""
template_items: Optional[List[Dict]] = None
template_name: Optional[str] = None
template_default: bool = True
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "TemplateInfo":
"""从字典创建TemplateInfo实例
Args:
data: 包含必要字段的字典
Returns:
TemplateInfo: 新的实例
"""
return cls(
template_items=data.get("template_items"),
template_name=data.get("template_name"),
template_default=data.get("template_default", True),
)
@dataclass @dataclass
class BaseMessageInfo: class BaseMessageInfo:
"""消息信息类""" """消息信息类"""
@@ -112,13 +169,15 @@ class BaseMessageInfo:
time: Optional[int] = None time: Optional[int] = None
group_info: Optional[GroupInfo] = None group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None user_info: Optional[UserInfo] = None
format_info: Optional[FormatInfo] = None
template_info: Optional[TemplateInfo] = None
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
"""转换为字典格式""" """转换为字典格式"""
result = {} result = {}
for field, value in asdict(self).items(): for field, value in asdict(self).items():
if value is not None: if value is not None:
if isinstance(value, (GroupInfo, UserInfo)): if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
result[field] = value.to_dict() result[field] = value.to_dict()
else: else:
result[field] = value result[field] = value
@@ -136,12 +195,16 @@ class BaseMessageInfo:
""" """
group_info = GroupInfo.from_dict(data.get("group_info", {})) group_info = GroupInfo.from_dict(data.get("group_info", {}))
user_info = UserInfo.from_dict(data.get("user_info", {})) user_info = UserInfo.from_dict(data.get("user_info", {}))
format_info = FormatInfo.from_dict(data.get("format_info", {}))
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
return cls( return cls(
platform=data.get("platform"), platform=data.get("platform"),
message_id=data.get("message_id"), message_id=data.get("message_id"),
time=data.get("time"), time=data.get("time"),
group_info=group_info, group_info=group_info,
user_info=user_info, user_info=user_info,
format_info=format_info,
template_info=template_info,
) )

View File

@@ -0,0 +1,98 @@
import unittest
import asyncio
import aiohttp
from api import BaseMessageAPI
from message_base import (
BaseMessageInfo,
UserInfo,
GroupInfo,
FormatInfo,
TemplateInfo,
MessageBase,
Seg,
)
send_url = "http://localhost"
receive_port = 18002 # 接收消息的端口
send_port = 18000 # 发送消息的端口
test_endpoint = "/api/message"
# 创建并启动API实例
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""测试前的设置"""
self.received_messages = []
async def message_handler(message):
self.received_messages.append(message)
self.api = api
self.api.register_message_handler(message_handler)
self.server_task = asyncio.create_task(self.api.run())
try:
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
except asyncio.TimeoutError:
self.skipTest("服务器启动超时")
async def asyncTearDown(self):
"""测试后的清理"""
if hasattr(self, "server_task"):
await self.api.stop() # 先调用正常的停止流程
if not self.server_task.done():
self.server_task.cancel()
try:
await asyncio.wait_for(self.server_task, timeout=100)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
async def test_send_and_receive_message(self):
"""测试向运行中的API发送消息并接收响应"""
# 准备测试消息
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
format_info = FormatInfo(
content_format=["text"], accept_format=["text", "emoji", "reply"]
)
template_info = None
message_info = BaseMessageInfo(
platform="qq",
message_id=12345678,
time=12345678,
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
message = MessageBase(
message_info=message_info,
raw_message="测试消息",
message_segment=Seg(type="text", data="测试消息"),
)
test_message = message.to_dict()
# 发送测试消息到发送端口
async with aiohttp.ClientSession() as session:
async with session.post(
f"{send_url}:{send_port}{test_endpoint}",
json=test_message,
) as response:
response_data = await response.json()
self.assertEqual(response.status, 200)
self.assertEqual(response_data["status"], "success")
try:
async with asyncio.timeout(5): # 设置5秒超时
while len(self.received_messages) == 0:
await asyncio.sleep(0.1)
received_message = self.received_messages[0]
print(received_message)
self.received_messages.clear()
except asyncio.TimeoutError:
self.fail("等待接收消息超时")
if __name__ == "__main__":
unittest.main()

View File

@@ -6,15 +6,13 @@ from typing import Tuple, Union
import aiohttp import aiohttp
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from nonebot import get_driver
import base64 import base64
from PIL import Image from PIL import Image
import io import io
import os
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
driver = get_driver()
config = driver.config
logger = get_module_logger("model_utils") logger = get_module_logger("model_utils")
@@ -34,8 +32,9 @@ class LLM_request:
def __init__(self, model, **kwargs): def __init__(self, model, **kwargs):
# 将大写的配置键转换为小写并从config中获取实际值 # 将大写的配置键转换为小写并从config中获取实际值
try: try:
self.api_key = getattr(config, model["key"]) self.api_key = os.environ[model["key"]]
self.base_url = getattr(config, model["base_url"]) self.base_url = os.environ[model["base_url"]]
print(self.api_key, self.base_url)
except AttributeError as e: except AttributeError as e:
logger.error(f"原始 model dict 信息:{model}") logger.error(f"原始 model dict 信息:{model}")
logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")

View File

@@ -3,7 +3,6 @@ import json
import re import re
from typing import Dict, Union from typing import Dict, Union
from nonebot import get_driver
# 添加项目根目录到 Python 路径 # 添加项目根目录到 Python 路径
@@ -14,9 +13,6 @@ from src.common.logger import get_module_logger
logger = get_module_logger("scheduler") logger = get_module_logger("scheduler")
driver = get_driver()
config = driver.config
class ScheduleGenerator: class ScheduleGenerator:
enable_output: bool = True enable_output: bool = True
@@ -183,5 +179,7 @@ class ScheduleGenerator:
logger.info(f"时间[{time_str}]: 活动[{activity}]") logger.info(f"时间[{time_str}]: 活动[{activity}]")
logger.info("==================") logger.info("==================")
self.enable_output = False self.enable_output = False
# 当作为组件导入时使用的实例 # 当作为组件导入时使用的实例
bot_schedule = ScheduleGenerator() bot_schedule = ScheduleGenerator()

View File

@@ -1,4 +0,0 @@
更新版本后建议删除数据库messages中所有内容不然会出现报错
该操作不会影响你的记忆
如果显示配置文件版本过低运行根目录的bat