Merge remote-tracking branch 'upstream/debug' into debug

tcmofashi committed 2025-03-12 21:47:53 +08:00
20 changed files with 189 additions and 359 deletions

View File

@@ -7,7 +7,6 @@ from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State
from ...common.database import Database
from ..moods.moods import MoodManager # import the mood manager
from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import GroupInfo, UserInfo
@@ -83,7 +83,6 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# kick off initialization in the event loop
@@ -111,11 +110,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# create indexes
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +167,7 @@ class ChatManager:
return stream
# check whether it exists in the database
data = self.db.chat_streams.find_one({"stream_id": stream_id})
data = db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# update user info and group info
@@ -204,7 +203,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.chat_streams.update_one(
db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +215,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.chat_streams.find({})
all_streams = db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
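The common thread in this commit: every per-class self.db = Database.get_instance() handle is dropped in favor of a single module-level db imported from src/common/database.py. That module itself is not shown in the diff; below is a minimal sketch of what it plausibly contains, reconstructed from the Database.initialize(...) blocks removed further down (the env variable names and URI construction come from those removed blocks; everything else is assumption):

# src/common/database.py -- hypothetical sketch, not part of this diff
import os
import pymongo

def _build_uri() -> str:
    # an explicit MONGODB_URI wins; otherwise assemble one from the parts,
    # mirroring the removed Database.initialize() logic
    uri = os.getenv("MONGODB_URI")
    if uri:
        return uri
    host = os.getenv("MONGODB_HOST", "127.0.0.1")
    port = int(os.getenv("MONGODB_PORT", "27017"))
    username = os.getenv("MONGODB_USERNAME")
    password = os.getenv("MONGODB_PASSWORD")
    auth_source = os.getenv("MONGODB_AUTH_SOURCE") or "admin"
    db_name = os.getenv("DATABASE_NAME", "MegBot")
    if username and password:
        return f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
    return f"mongodb://{host}:{port}"

_client = pymongo.MongoClient(_build_uri())
_client.server_info()  # fail fast if the server is unreachable

# every importer shares this one Database object:
#     from src.common.database import db
db = _client[os.getenv("DATABASE_NAME", "MegBot")]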

View File

@@ -12,7 +12,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64
@@ -30,12 +30,10 @@ class EmojiManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(
@@ -50,7 +48,6 @@ class EmojiManager:
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self.db = Database.get_instance()
self._ensure_emoji_collection()
self._ensure_emoji_dir()
self._initialized = True
@@ -78,16 +75,16 @@ class EmojiManager:
Without an index, every query has to scan all of the data; creating indexes greatly improves query efficiency.
"""
if "emoji" not in self.db.list_collection_names():
self.db.create_collection("emoji")
self.db.emoji.create_index([("embedding", "2dsphere")])
self.db.emoji.create_index([("filename", 1)], unique=True)
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -118,7 +115,7 @@ class EmojiManager:
try:
# fetch all emoji records
all_emojis = list(self.db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1}))
all_emojis = list(db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -155,7 +152,7 @@ class EmojiManager:
if selected_emoji and "path" in selected_emoji:
# bump the usage count
self.db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})"
@@ -235,14 +232,14 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# check whether it has already been registered
existing_emoji = self.db["emoji"].find_one({"filename": filename})
existing_emoji = db["emoji"].find_one({"filename": filename})
description = None
if existing_emoji:
# even if the emoji already exists, check whether it needs syncing to the images collection
description = existing_emoji.get("discription")
# check whether it exists in the images collection
existing_image = image_manager.db.images.find_one({"hash": image_hash})
existing_image = db.images.find_one({"hash": image_hash})
if not existing_image:
# sync to the images collection
image_doc = {
@@ -252,7 +249,7 @@ class EmojiManager:
"description": description,
"timestamp": int(time.time()),
}
image_manager.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# save the description to the image_descriptions collection
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"Synced existing emoji to the images collection: {filename}")
@@ -295,7 +292,7 @@ class EmojiManager:
}
# save to the emoji collection
self.db["emoji"].insert_one(emoji_record)
db["emoji"].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
@@ -307,7 +304,7 @@ class EmojiManager:
"description": description,
"timestamp": int(time.time()),
}
image_manager.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# save the description to the image_descriptions collection
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"Synced and saved to the images collection: {filename}")
@@ -331,7 +328,7 @@ class EmojiManager:
try:
self._ensure_db()
# fetch all emoji records
all_emojis = list(self.db.emoji.find())
all_emojis = list(db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -339,26 +336,26 @@ class EmojiManager:
try:
if "path" not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({"_id": emoji["_id"]})
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
if "embedding" not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({"_id": emoji["_id"]})
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
if "hash" not in emoji:
logger.warning(f"发现缺失记录缺少hash字段ID: {emoji.get('_id', 'unknown')}")
hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest()
self.db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
# check whether the file still exists
if not os.path.exists(emoji["path"]):
logger.warning(f"Emoji file has been deleted: {emoji['path']}")
# remove the record from the database
result = self.db.emoji.delete_one({"_id": emoji["_id"]})
result = db.emoji.delete_one({"_id": emoji["_id"]})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -369,7 +366,7 @@ class EmojiManager:
continue
# verify the cleanup result
remaining_count = self.db.emoji.count_documents({})
remaining_count = db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..models.utils_model import LLM_request
from .config import global_config
from .message import MessageRecv, MessageThinking, Message
@@ -34,7 +34,6 @@ class ResponseGenerator:
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = "r1" # use R1 by default
async def generate_response(
@@ -154,7 +153,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.reasoning_logs.insert_one(
db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
@@ -211,7 +210,6 @@ class ResponseGenerator:
class InitiativeMessageGenerate:
def __init__(self):
self.db = Database.get_instance()
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(

View File

@@ -3,7 +3,7 @@ import time
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..memory_system.memory import hippocampus, memory_graph
from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
@@ -16,7 +16,6 @@ class PromptBuilder:
def __init__(self):
self.prompt_built = ''
self.activate_messages = ''
self.db = Database.get_instance()
@@ -76,7 +75,7 @@ class PromptBuilder:
chat_in_group=True
chat_talking_prompt = ''
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
@@ -199,7 +198,7 @@ class PromptBuilder:
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
@@ -311,7 +310,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.knowledges.aggregate(pipeline))
results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import UserInfo
from .chat_stream import ChatStream
@@ -167,14 +167,12 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.relationships.find({})
# 依次加载每条记录
@@ -205,7 +203,6 @@ class RelationshipManager:
age = relationship.age
saved = relationship.saved
db = Database.get_instance()
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {

View File

@@ -1,15 +1,12 @@
from typing import Optional, Union
from ...common.database import Database
from ...common.database import db
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
@@ -23,7 +20,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.messages.insert_one(message_data)
db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -16,6 +16,7 @@ from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
from ...common.database import db
driver = get_driver()
config = driver.config
@@ -76,11 +77,10 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
@@ -115,11 +115,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return []
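The hunk above only shows the signature change; the body is elided. Under the assumption that the function now queries the shared db handle directly, here is a sketch of the lookup, following the anchor-then-read-forward pattern that get_random_chat_from_db uses later in this commit (the body is a hypothetical reconstruction, not the repository's code):

def get_closest_chat_from_db(length: int, timestamp: str):
    """Fetch up to `length` messages starting at the newest message
    at or before `timestamp` (hypothetical reconstruction)."""
    closest = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[("time", -1)])
    if closest is None:
        return []
    # read forward from the anchor message, oldest first
    return list(
        db.messages.find({"time": {"$gte": closest["time"]}}).sort("time", 1).limit(length)
    )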
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
db: Database实例
group_id: 群组ID
limit: 获取消息数量默认12条
@@ -161,7 +160,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.messages.find(
{"chat_id": chat_stream_id},
{

View File

@@ -10,7 +10,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..models.utils_model import LLM_request
@@ -25,13 +25,11 @@ class ImageManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
@@ -44,20 +42,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if "images" not in self.db.list_collection_names():
self.db.create_collection("images")
if "images" not in db.list_collection_names():
db.create_collection("images")
# create indexes
self.db.images.create_index([("hash", 1)], unique=True)
self.db.images.create_index([("url", 1)])
self.db.images.create_index([("path", 1)])
db.images.create_index([("hash", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in self.db.list_collection_names():
self.db.create_collection("image_descriptions")
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# create indexes
self.db.image_descriptions.create_index([("hash", 1)], unique=True)
self.db.image_descriptions.create_index([("type", 1)])
db.image_descriptions.create_index([("hash", 1)], unique=True)
db.image_descriptions.create_index([("type", 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -69,9 +67,7 @@ class ImageManager:
Returns:
Optional[str]: the description text, or None if it does not exist
"""
if image_hash is None:
return
result = self.db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
@@ -82,9 +78,7 @@ class ImageManager:
description: the description text
description_type: description type ('emoji' or 'image')
"""
if image_hash is None:
return
self.db.image_descriptions.update_one(
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{"$set": {"description": description, "timestamp": int(time.time())}},
upsert=True,
@@ -120,7 +114,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# deduplicate
existing = self.db.images.find_one({"hash": image_hash})
existing = db.images.find_one({"hash": image_hash})
if existing:
return existing["path"]
@@ -141,7 +135,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.insert_one(image_doc)
db.images.insert_one(image_doc)
return file_path
@@ -158,7 +152,7 @@ class ImageManager:
"""
try:
# first check whether it already exists
existing = self.db.images.find_one({"url": url})
existing = db.images.find_one({"url": url})
if existing:
return existing["path"]
@@ -201,7 +195,7 @@ class ImageManager:
Returns:
bool: whether it exists
"""
return self.db.images.find_one({"url": url}) is not None
return db.images.find_one({"url": url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -224,7 +218,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.images.find_one({"hash": image_hash}) is not None
return db.images.find_one({"hash": image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -268,7 +262,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
@@ -328,7 +322,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")

View File

@@ -13,7 +13,7 @@ from loguru import logger
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database # use the correct import syntax
from src.common.database import db # use the correct import syntax
# load the .env.dev file
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
@@ -23,7 +23,6 @@ load_dotenv(env_path)
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # use a networkx graph structure
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2)
@@ -96,7 +95,7 @@ class Memory_graph:
dot_data = {
"concept": node
}
self.db.store_memory_dots.insert_one(dot_data)
db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
@@ -106,7 +105,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str):
# fetch the chat records closest to the given timestamp from the database
chat_text = ''
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # debug output
closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # debug output
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@@ -115,7 +114,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # grab the group_id
# fetch length messages after that timestamp from the same group
chat_record = list(
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@@ -130,50 +129,39 @@ class Memory_graph:
def save_graph_to_db(self):
# clear the existing graph data
self.db.graph_data.delete_many({})
db.graph_data.delete_many({})
# save nodes
for node in self.G.nodes(data=True):
node_data = {
'concept': node[0],
'memory_items': node[1].get('memory_items', []) # defaults to an empty list
}
self.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
# save edges
for edge in self.G.edges():
edge_data = {
'source': edge[0],
'target': edge[1]
}
self.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self):
# clear the current graph
self.G.clear()
# load nodes
nodes = self.db.graph_data.nodes.find()
nodes = db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# load edges
edges = self.db.graph_data.edges.find()
edges = db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'])
def main():
# initialize the database
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph()
memory_graph.load_graph_from_db()

View File

@@ -10,12 +10,12 @@ import networkx as nx
from loguru import logger
from nonebot import get_driver
from ...common.database import Database # use the correct import syntax
from ...common.database import db # use the correct import syntax
from ..chat.config import global_config
from ..chat.utils import (
calculate_information_content,
cosine_similarity,
get_cloest_chat_from_db,
get_closest_chat_from_db,
text_to_vector,
)
from ..models.utils_model import LLM_request
@@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # use a networkx graph structure
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# avoid self-loops
@@ -191,19 +190,19 @@ class Hippocampus:
# short-term: 1h, mid-term: 4h, long-term: 24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
@@ -349,7 +348,7 @@ class Hippocampus:
def sync_memory_to_db(self):
"""检查并同步内存中的图结构与数据库"""
# 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# convert database nodes to a dict for quick lookup
@@ -377,7 +376,7 @@ class Hippocampus:
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
else:
# get the node's feature hash from the database
db_node = db_nodes_dict[concept]
@@ -385,7 +384,7 @@ class Hippocampus:
# update the node if the feature hashes differ
if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
@@ -396,7 +395,7 @@ class Hippocampus:
)
# handle edge data
db_edges = list(self.memory_graph.db.graph_data.edges.find())
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True))
# build a dict of edge hashes
@@ -428,11 +427,11 @@ class Hippocampus:
'created_time': created_time,
'last_modified': last_modified
}
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
else:
# check whether the edge's feature hash changed
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
@@ -451,7 +450,7 @@ class Hippocampus:
self.memory_graph.G.clear()
# load all nodes from the database
nodes = list(self.memory_graph.db.graph_data.nodes.find())
nodes = list(db.graph_data.nodes.find())
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
@@ -468,7 +467,7 @@ class Hippocampus:
if 'last_modified' not in node:
update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': update_data}
)
@@ -485,7 +484,7 @@ class Hippocampus:
last_modified=last_modified)
# load all edges from the database
edges = list(self.memory_graph.db.graph_data.edges.find())
edges = list(db.graph_data.edges.find())
for edge in edges:
source = edge['source']
target = edge['target']
@@ -501,7 +500,7 @@ class Hippocampus:
if 'last_modified' not in edge:
update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': update_data}
)
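sync_memory_to_db writes a node only when the stored hash field differs from one computed over the in-memory content. The hashing scheme itself never appears in this diff; below is a minimal sketch of the idea, assuming from src.common.database import db, with the helper name and the md5-over-sorted-items choice both assumptions:

import hashlib

def compute_memory_hash(concept: str, memory_items: list) -> str:
    # order-independent fingerprint of a node's content (assumed scheme)
    joined = concept + "|" + "|".join(sorted(str(item) for item in memory_items))
    return hashlib.md5(joined.encode("utf-8")).hexdigest()

def sync_node(concept: str, memory_items: list) -> None:
    memory_hash = compute_memory_hash(concept, memory_items)
    db_node = db.graph_data.nodes.find_one({"concept": concept})
    if db_node is None:
        db.graph_data.nodes.insert_one(
            {"concept": concept, "memory_items": memory_items, "hash": memory_hash}
        )
    elif db_node.get("hash") != memory_hash:
        # write only when the fingerprints differ, as in the hunks above
        db.graph_data.nodes.update_one(
            {"concept": concept},
            {"$set": {"memory_items": memory_items, "hash": memory_hash}},
        )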

View File

@@ -19,7 +19,7 @@ import jieba
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
from src.common.database import Database
from src.common.database import db
from src.plugins.memory_system.offline_llm import LLMModel
# get the directory of the current file
@@ -49,7 +49,7 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
@@ -91,7 +91,6 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # use a networkx graph structure
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# if the edge already exists, increase its strength
@@ -186,19 +185,19 @@ class Hippocampus:
# short-term: 1h, mid-term: 4h, long-term: 24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
@@ -323,7 +322,7 @@ class Hippocampus:
self.memory_graph.G.clear()
# load all nodes from the database
nodes = self.memory_graph.db.graph_data.nodes.find()
nodes = db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
@@ -334,7 +333,7 @@ class Hippocampus:
self.memory_graph.G.add_node(concept, memory_items=memory_items)
# load all edges from the database
edges = self.memory_graph.db.graph_data.edges.find()
edges = db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
@@ -371,7 +370,7 @@ class Hippocampus:
Use feature hashes to quickly decide whether an update is needed
"""
# fetch all nodes in the database and all nodes in memory
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# convert database nodes to a dict for quick lookup
@@ -394,7 +393,7 @@ class Hippocampus:
'memory_items': memory_items,
'hash': memory_hash
}
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
else:
# get the node's feature hash from the database
db_node = db_nodes_dict[concept]
@@ -403,7 +402,7 @@ class Hippocampus:
# update the node if the feature hashes differ
if db_hash != memory_hash:
# logger.info(f"Updating node content: {concept}")
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
@@ -416,10 +415,10 @@ class Hippocampus:
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
# logger.info(f"删除多余节点: {db_node['concept']}")
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# handle edge data
db_edges = list(self.memory_graph.db.graph_data.edges.find())
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges())
# build a dict of edge hashes
@@ -445,12 +444,12 @@ class Hippocampus:
'num': 1,
'hash': edge_hash
}
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
else:
# check whether the edge's feature hash changed
if db_edge_dict[edge_key]['hash'] != edge_hash:
logger.info(f"Updating edge: {source} - {target}")
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {'hash': edge_hash}}
)
@@ -461,7 +460,7 @@ class Hippocampus:
if edge_key not in memory_edge_set:
source, target = edge_key
logger.info(f"删除多余边: {source} - {target}")
self.memory_graph.db.graph_data.edges.delete_one({
db.graph_data.edges.delete_one({
'source': source,
'target': target
})
@@ -487,9 +486,9 @@ class Hippocampus:
topic: the node concept to delete
"""
# delete the node
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
db.graph_data.nodes.delete_one({'concept': topic})
# delete all edges involving this node
self.memory_graph.db.graph_data.edges.delete_many({
db.graph_data.edges.delete_many({
'$or': [
{'source': topic},
{'target': topic}
@@ -902,17 +901,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
plt.show()
async def main():
# initialize the database
logger.info("Initializing database connection...")
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}

View File

@@ -38,7 +38,7 @@ import jieba
# from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database
from src.common.database import db
from src.plugins.memory_system.offline_llm import LLMModel
# get the directory of the current file
@@ -56,45 +56,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# test the connection
client.server_info()
logger.success("MongoDB connected successfully!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
@@ -108,7 +69,7 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns:
@@ -163,7 +124,7 @@ class Memory_cortex:
default_time = datetime.datetime.now().timestamp()
# load all nodes from the database
nodes = self.memory_graph.db.graph_data.nodes.find()
nodes = db.graph_data.nodes.find()
for node in nodes:
concept = node['concept']
memory_items = node.get('memory_items', [])
@@ -180,7 +141,7 @@ class Memory_cortex:
created_time = default_time
last_modified = default_time
# update the node in the database
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'created_time': created_time,
@@ -196,7 +157,7 @@ class Memory_cortex:
last_modified=last_modified)
# load all edges from the database
edges = self.memory_graph.db.graph_data.edges.find()
edges = db.graph_data.edges.find()
for edge in edges:
source = edge['source']
target = edge['target']
@@ -212,7 +173,7 @@ class Memory_cortex:
created_time = default_time
last_modified = default_time
# update the edge in the database
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'created_time': created_time,
@@ -256,7 +217,7 @@ class Memory_cortex:
current_time = datetime.datetime.now().timestamp()
# fetch all nodes in the database and all nodes in memory
db_nodes = list(self.memory_graph.db.graph_data.nodes.find())
db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True))
# convert database nodes to a dict for quick lookup
@@ -280,7 +241,7 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time)
}
self.memory_graph.db.graph_data.nodes.insert_one(node_data)
db.graph_data.nodes.insert_one(node_data)
else:
# get the node's feature hash from the database
db_node = db_nodes_dict[concept]
@@ -288,7 +249,7 @@ class Memory_cortex:
# update the node if the feature hashes differ
if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one(
db.graph_data.nodes.update_one(
{'concept': concept},
{'$set': {
'memory_items': memory_items,
@@ -301,10 +262,10 @@ class Memory_cortex:
memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes:
if db_node['concept'] not in memory_concepts:
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']})
db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# handle edge data
db_edges = list(self.memory_graph.db.graph_data.edges.find())
db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True))
# build a dict of edge hashes
@@ -332,11 +293,11 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time)
}
self.memory_graph.db.graph_data.edges.insert_one(edge_data)
db.graph_data.edges.insert_one(edge_data)
else:
# check whether the edge's feature hash changed
if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one(
db.graph_data.edges.update_one(
{'source': source, 'target': target},
{'$set': {
'hash': edge_hash,
@@ -350,7 +311,7 @@ class Memory_cortex:
for edge_key in db_edge_dict:
if edge_key not in memory_edge_set:
source, target = edge_key
self.memory_graph.db.graph_data.edges.delete_one({
db.graph_data.edges.delete_one({
'source': source,
'target': target
})
@@ -365,9 +326,9 @@ class Memory_cortex:
topic: the node concept to delete
"""
# delete the node
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic})
db.graph_data.nodes.delete_one({'concept': topic})
# delete all edges involving this node
self.memory_graph.db.graph_data.edges.delete_many({
db.graph_data.edges.delete_many({
'$or': [
{'source': topic},
{'target': topic}
@@ -377,7 +338,6 @@ class Memory_cortex:
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # use a networkx graph structure
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2):
# avoid self-loops
@@ -492,19 +452,19 @@ class Hippocampus:
# short-term: 1h, mid-term: 4h, long-term: 24h
for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time)
messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages:
chat_samples.append(messages)
@@ -1134,7 +1094,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# initialize the database
logger.info("Initializing database connection...")
db = Database.get_instance()
start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}

View File

@@ -10,7 +10,7 @@ from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
driver = get_driver()
@@ -34,17 +34,16 @@ class LLM_request:
self.pri_out = model.get("pri_out", 0)
# get the database instance
self.db = Database.get_instance()
self._init_database()
def _init_database(self):
"""初始化数据库集合"""
try:
# create indexes for the llm_usage collection
self.db.llm_usage.create_index([("timestamp", 1)])
self.db.llm_usage.create_index([("model_name", 1)])
self.db.llm_usage.create_index([("user_id", 1)])
self.db.llm_usage.create_index([("request_type", 1)])
db.llm_usage.create_index([("timestamp", 1)])
db.llm_usage.create_index([("model_name", 1)])
db.llm_usage.create_index([("user_id", 1)])
db.llm_usage.create_index([("request_type", 1)])
except Exception:
logger.error("创建数据库索引失败")
@@ -73,7 +72,7 @@ class LLM_request:
"status": "success",
"timestamp": datetime.now()
}
self.db.llm_usage.insert_one(usage_data)
db.llm_usage.insert_one(usage_data)
logger.info(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -8,7 +8,7 @@ from nonebot import get_driver
from src.plugins.chat.config import global_config
from ...common.database import Database # use the correct import syntax
from ...common.database import db # use the correct import syntax
from ..models.utils_model import LLM_request
driver = get_driver()
@@ -19,7 +19,6 @@ class ScheduleGenerator:
# the model is specified by the global_config.llm_normal dict config
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
self.db = Database.get_instance()
self.today_schedule_text = ""
self.today_schedule = {}
self.tomorrow_schedule_text = ""
@@ -46,7 +45,7 @@ class ScheduleGenerator:
schedule_text = str
existing_schedule = self.db.schedule.find_one({"date": date_str})
existing_schedule = db.schedule.find_one({"date": date_str})
if existing_schedule:
logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"]
@@ -63,7 +62,7 @@ class ScheduleGenerator:
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
@@ -143,7 +142,7 @@ class ScheduleGenerator:
"""打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text):
logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():

View File

@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Dict
from loguru import logger
from ...common.database import Database
from ...common.database import db
class LLMStatistics:
@@ -15,7 +15,6 @@ class LLMStatistics:
Args:
output_file: path of the statistics output file
"""
self.db = Database.get_instance()
self.output_file = output_file
self.running = False
self.stats_thread = None
@@ -53,7 +52,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float)
}
cursor = self.db.llm_usage.find({
cursor = db.llm_usage.find({
"timestamp": {"$gte": start_time}
})
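With the singleton gone, the collector reads llm_usage through the shared db handle, as the hunk above shows. A sketch of how the per-model rollup suggested by costs_by_model might proceed; the cost field, and any field not shown in this commit, is an assumption:

from collections import defaultdict
from datetime import datetime, timedelta

from src.common.database import db  # the shared handle introduced by this commit

def collect_usage_stats(hours: int = 24) -> dict:
    start_time = datetime.now() - timedelta(hours=hours)
    stats = {"total_requests": 0, "costs_by_model": defaultdict(float)}
    for doc in db.llm_usage.find({"timestamp": {"$gte": start_time}}):
        stats["total_requests"] += 1
        # "cost" is an assumed per-record field; the commit only shows the query
        stats["costs_by_model"][doc.get("model_name", "unknown")] += doc.get("cost", 0.0)
    return stats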

View File

@@ -14,7 +14,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# now the src modules can be imported
from src.common.database import Database
from src.common.database import db
# load the .env.prod file in the root directory
env_path = os.path.join(root_path, ".env.prod")
@@ -24,18 +24,6 @@ load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
# initialize the database connection
if Database._instance is None:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
@@ -176,7 +164,7 @@ class KnowledgeLibrary:
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = self.db.processed_files.find_one({"file_path": file_path})
processed_record = db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
@@ -197,14 +185,14 @@ class KnowledgeLibrary:
"split_length": knowledge_length,
"created_at": datetime.now()
}
self.db.knowledges.insert_one(knowledge)
db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
self.db.knowledges.processed_files.update_one(
db.knowledges.processed_files.update_one(
{"file_path": file_path},
{
"$set": {
@@ -322,7 +310,7 @@ class KnowledgeLibrary:
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
]
results = list(self.db.knowledges.aggregate(pipeline))
results = list(db.knowledges.aggregate(pipeline))
return results
# create the singleton instance
@@ -346,7 +334,7 @@ if __name__ == "__main__":
elif choice == '2':
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == 'y':
knowledge_library.db.knowledges.delete_many({})
db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == '1':