Merge pull request #237 from Rikki-Zero/debug
refactor: 修复database单例多次初始化的问题,改变instance默认返回实例的类型,缩短db相关函数调用时的object名
This commit is contained in:
@@ -32,18 +32,6 @@ _message_manager_started = False
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
logger.success("初始化数据库成功")
|
||||
|
||||
|
||||
# 初始化表情管理器
|
||||
emoji_manager.initialize()
|
||||
|
||||
|
||||
@@ -111,11 +111,11 @@ class ChatManager:
|
||||
|
||||
def _ensure_collection(self):
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if "chat_streams" not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection("chat_streams")
|
||||
if "chat_streams" not in self.db.list_collection_names():
|
||||
self.db.create_collection("chat_streams")
|
||||
# 创建索引
|
||||
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
self.db.db.chat_streams.create_index(
|
||||
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
self.db.chat_streams.create_index(
|
||||
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
|
||||
)
|
||||
|
||||
@@ -168,7 +168,7 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
|
||||
data = self.db.chat_streams.find_one({"stream_id": stream_id})
|
||||
if data:
|
||||
stream = ChatStream.from_dict(data)
|
||||
# 更新用户信息和群组信息
|
||||
@@ -204,7 +204,7 @@ class ChatManager:
|
||||
async def _save_stream(self, stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
self.db.db.chat_streams.update_one(
|
||||
self.db.chat_streams.update_one(
|
||||
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
|
||||
)
|
||||
stream.saved = True
|
||||
@@ -216,7 +216,7 @@ class ChatManager:
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
all_streams = self.db.db.chat_streams.find({})
|
||||
all_streams = self.db.chat_streams.find({})
|
||||
for data in all_streams:
|
||||
stream = ChatStream.from_dict(data)
|
||||
self.streams[stream.stream_id] = stream
|
||||
|
||||
@@ -76,16 +76,16 @@ class EmojiManager:
|
||||
|
||||
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
|
||||
"""
|
||||
if 'emoji' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('emoji')
|
||||
self.db.db.emoji.create_index([('embedding', '2dsphere')])
|
||||
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
||||
if 'emoji' not in self.db.list_collection_names():
|
||||
self.db.create_collection('emoji')
|
||||
self.db.emoji.create_index([('embedding', '2dsphere')])
|
||||
self.db.emoji.create_index([('filename', 1)], unique=True)
|
||||
|
||||
def record_usage(self, emoji_id: str):
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
self.db.db.emoji.update_one(
|
||||
self.db.emoji.update_one(
|
||||
{'_id': emoji_id},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
@@ -119,7 +119,7 @@ class EmojiManager:
|
||||
|
||||
try:
|
||||
# 获取所有表情包
|
||||
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
|
||||
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("数据库中没有任何表情包")
|
||||
@@ -157,7 +157,7 @@ class EmojiManager:
|
||||
|
||||
if selected_emoji and 'path' in selected_emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
self.db.emoji.update_one(
|
||||
{'_id': selected_emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
@@ -239,7 +239,7 @@ class EmojiManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||
existing_emoji = self.db['emoji'].find_one({'filename': filename})
|
||||
description = None
|
||||
|
||||
if existing_emoji:
|
||||
@@ -305,7 +305,7 @@ class EmojiManager:
|
||||
}
|
||||
|
||||
# 保存到emoji数据库
|
||||
self.db.db['emoji'].insert_one(emoji_record)
|
||||
self.db['emoji'].insert_one(emoji_record)
|
||||
logger.success(f"注册新表情包: {filename}")
|
||||
logger.info(f"描述: {description}")
|
||||
|
||||
@@ -346,7 +346,7 @@ class EmojiManager:
|
||||
try:
|
||||
self._ensure_db()
|
||||
# 获取所有表情包记录
|
||||
all_emojis = list(self.db.db.emoji.find())
|
||||
all_emojis = list(self.db.emoji.find())
|
||||
removed_count = 0
|
||||
total_count = len(all_emojis)
|
||||
|
||||
@@ -354,13 +354,13 @@ class EmojiManager:
|
||||
try:
|
||||
if 'path' not in emoji:
|
||||
logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}")
|
||||
self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
removed_count += 1
|
||||
continue
|
||||
|
||||
if 'embedding' not in emoji:
|
||||
logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}")
|
||||
self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
removed_count += 1
|
||||
continue
|
||||
|
||||
@@ -368,7 +368,7 @@ class EmojiManager:
|
||||
if not os.path.exists(emoji['path']):
|
||||
logger.warning(f"表情包文件已被删除: {emoji['path']}")
|
||||
# 从数据库中删除记录
|
||||
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
result = self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||
if result.deleted_count > 0:
|
||||
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
|
||||
removed_count += 1
|
||||
@@ -379,7 +379,7 @@ class EmojiManager:
|
||||
continue
|
||||
|
||||
# 验证清理结果
|
||||
remaining_count = self.db.db.emoji.count_documents({})
|
||||
remaining_count = self.db.emoji.count_documents({})
|
||||
if removed_count > 0:
|
||||
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
|
||||
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
|
||||
|
||||
@@ -154,7 +154,7 @@ class ResponseGenerator:
|
||||
reasoning_content: str,
|
||||
):
|
||||
"""保存对话记录到数据库"""
|
||||
self.db.db.reasoning_logs.insert_one(
|
||||
self.db.reasoning_logs.insert_one(
|
||||
{
|
||||
"time": time.time(),
|
||||
"chat_id": message.chat_stream.stream_id,
|
||||
|
||||
@@ -311,7 +311,7 @@ class PromptBuilder:
|
||||
{"$project": {"content": 1, "similarity": 1}}
|
||||
]
|
||||
|
||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||
results = list(self.db.knowledges.aggregate(pipeline))
|
||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
||||
|
||||
if not results:
|
||||
|
||||
@@ -168,7 +168,7 @@ class RelationshipManager:
|
||||
async def load_all_relationships(self):
|
||||
"""加载所有关系对象"""
|
||||
db = Database.get_instance()
|
||||
all_relationships = db.db.relationships.find({})
|
||||
all_relationships = db.relationships.find({})
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
|
||||
@@ -176,7 +176,7 @@ class RelationshipManager:
|
||||
"""每5分钟自动保存一次关系数据"""
|
||||
db = Database.get_instance()
|
||||
# 获取所有关系记录
|
||||
all_relationships = db.db.relationships.find({})
|
||||
all_relationships = db.relationships.find({})
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
await self.load_relationship(data)
|
||||
@@ -206,7 +206,7 @@ class RelationshipManager:
|
||||
saved = relationship.saved
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
db.relationships.update_one(
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'$set': {
|
||||
'platform': platform,
|
||||
|
||||
@@ -23,7 +23,7 @@ class MessageStorage:
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
"topic": topic,
|
||||
}
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
self.db.messages.insert_one(message_data)
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
|
||||
|
||||
@@ -40,20 +40,20 @@ class ImageManager:
|
||||
|
||||
def _ensure_image_collection(self):
|
||||
"""确保images集合存在并创建索引"""
|
||||
if 'images' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('images')
|
||||
if 'images' not in self.db.list_collection_names():
|
||||
self.db.create_collection('images')
|
||||
# 创建索引
|
||||
self.db.db.images.create_index([('hash', 1)], unique=True)
|
||||
self.db.db.images.create_index([('url', 1)])
|
||||
self.db.db.images.create_index([('path', 1)])
|
||||
self.db.images.create_index([('hash', 1)], unique=True)
|
||||
self.db.images.create_index([('url', 1)])
|
||||
self.db.images.create_index([('path', 1)])
|
||||
|
||||
def _ensure_description_collection(self):
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if 'image_descriptions' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('image_descriptions')
|
||||
if 'image_descriptions' not in self.db.list_collection_names():
|
||||
self.db.create_collection('image_descriptions')
|
||||
# 创建索引
|
||||
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
||||
self.db.db.image_descriptions.create_index([('type', 1)])
|
||||
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
||||
self.db.image_descriptions.create_index([('type', 1)])
|
||||
|
||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
@@ -65,7 +65,7 @@ class ImageManager:
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result= self.db.db.image_descriptions.find_one({
|
||||
result= self.db.image_descriptions.find_one({
|
||||
'hash': image_hash,
|
||||
'type': description_type
|
||||
})
|
||||
@@ -79,7 +79,7 @@ class ImageManager:
|
||||
description: 描述文本
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
self.db.db.image_descriptions.update_one(
|
||||
self.db.image_descriptions.update_one(
|
||||
{'hash': image_hash, 'type': description_type},
|
||||
{
|
||||
'$set': {
|
||||
@@ -121,7 +121,7 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查重
|
||||
existing = self.db.db.images.find_one({'hash': image_hash})
|
||||
existing = self.db.images.find_one({'hash': image_hash})
|
||||
if existing:
|
||||
return existing['path']
|
||||
|
||||
@@ -142,7 +142,7 @@ class ImageManager:
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.insert_one(image_doc)
|
||||
self.db.images.insert_one(image_doc)
|
||||
|
||||
return file_path
|
||||
|
||||
@@ -159,7 +159,7 @@ class ImageManager:
|
||||
"""
|
||||
try:
|
||||
# 先查找是否已存在
|
||||
existing = self.db.db.images.find_one({'url': url})
|
||||
existing = self.db.images.find_one({'url': url})
|
||||
if existing:
|
||||
return existing['path']
|
||||
|
||||
@@ -203,7 +203,7 @@ class ImageManager:
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
return self.db.db.images.find_one({'url': url}) is not None
|
||||
return self.db.images.find_one({'url': url}) is not None
|
||||
|
||||
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
|
||||
"""检查图像是否已存在
|
||||
@@ -226,7 +226,7 @@ class ImageManager:
|
||||
return False
|
||||
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
return self.db.db.images.find_one({'hash': image_hash}) is not None
|
||||
return self.db.images.find_one({'hash': image_hash}) is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查哈希失败: {str(e)}")
|
||||
@@ -269,7 +269,7 @@ class ImageManager:
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.update_one(
|
||||
self.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
@@ -326,7 +326,7 @@ class ImageManager:
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.update_one(
|
||||
self.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
|
||||
Reference in New Issue
Block a user