refactor: 修复database单例多次初始化的问题,改变instance默认返回实例的类型,缩短db相关函数调用时的object名

This commit is contained in:
Rikki
2025-03-12 00:51:56 +08:00
parent 7c4d3ec3fb
commit 39018440d7
17 changed files with 88 additions and 117 deletions

View File

@@ -32,18 +32,6 @@ _message_manager_started = False
driver = get_driver()
config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")
# 初始化表情管理器
emoji_manager.initialize()

View File

@@ -111,11 +111,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +168,7 @@ class ChatManager:
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
data = self.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -204,7 +204,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
self.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +216,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
all_streams = self.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream

View File

@@ -76,16 +76,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
if 'emoji' not in self.db.list_collection_names():
self.db.create_collection('emoji')
self.db.emoji.create_index([('embedding', '2dsphere')])
self.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': emoji_id},
{'$inc': {'usage_count': 1}}
)
@@ -119,7 +119,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -157,7 +157,7 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
@@ -236,7 +236,7 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
@@ -298,7 +298,7 @@ class EmojiManager:
}
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
self.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
@@ -338,7 +338,7 @@ class EmojiManager:
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(self.db.db.emoji.find())
all_emojis = list(self.db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -346,13 +346,13 @@ class EmojiManager:
try:
if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
@@ -360,7 +360,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
result = self.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -371,7 +371,7 @@ class EmojiManager:
continue
# 验证清理结果
remaining_count = self.db.db.emoji.count_documents({})
remaining_count = self.db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -154,7 +154,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one(
self.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,

View File

@@ -311,7 +311,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
results = list(self.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -169,7 +169,7 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
@@ -177,7 +177,7 @@ class RelationshipManager:
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
await self.load_relationship(data)
@@ -207,7 +207,7 @@ class RelationshipManager:
saved = relationship.saved
db = Database.get_instance()
db.db.relationships.update_one(
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,

View File

@@ -25,7 +25,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.db.messages.insert_one(message_data)
self.db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -44,20 +44,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
if 'images' not in self.db.list_collection_names():
self.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
self.db.images.create_index([('hash', 1)], unique=True)
self.db.images.create_index([('url', 1)])
self.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
if 'image_descriptions' not in self.db.list_collection_names():
self.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -69,7 +69,7 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
result= self.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
@@ -83,7 +83,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.db.image_descriptions.update_one(
self.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
@@ -125,7 +125,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
existing = self.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
@@ -146,7 +146,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
self.db.images.insert_one(image_doc)
return file_path
@@ -163,7 +163,7 @@ class ImageManager:
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
existing = self.db.images.find_one({'url': url})
if existing:
return existing['path']
@@ -207,7 +207,7 @@ class ImageManager:
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
return self.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -230,7 +230,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
return self.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -273,7 +273,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -330,7 +330,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True