Merge pull request #237 from Rikki-Zero/debug

refactor: 修复database单例多次初始化的问题,改变instance默认返回实例的类型,缩短db相关函数调用时的object名
This commit is contained in:
AL76
2025-03-12 02:48:00 +08:00
committed by GitHub
16 changed files with 105 additions and 101 deletions

15
bot.py
View File

@@ -12,6 +12,8 @@ from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter
import platform
from src.common.database import Database
# 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ}
@@ -96,6 +98,17 @@ def load_env():
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def init_database():
    """Initialise the shared ``Database`` singleton from environment variables.

    Connection settings are read from the process environment:

    * ``MONGODB_URI``
    * ``MONGODB_HOST`` (falls back to ``"127.0.0.1"``)
    * ``MONGODB_PORT`` (falls back to ``"27017"``, converted with ``int``)
    * ``DATABASE_NAME`` (falls back to ``"MegBot"``)
    * ``MONGODB_USERNAME``, ``MONGODB_PASSWORD``, ``MONGODB_AUTH_SOURCE``

    Variables without a fallback are passed through as ``None`` when unset.
    The value returned by ``Database.initialize`` is discarded.

    Raises:
        ValueError: if ``MONGODB_PORT`` is set to something ``int`` cannot parse.
    """
    read_env = os.getenv
    connection_settings = {
        "uri": read_env("MONGODB_URI"),
        "host": read_env("MONGODB_HOST", "127.0.0.1"),
        "port": int(read_env("MONGODB_PORT", "27017")),
        "db_name": read_env("DATABASE_NAME", "MegBot"),
        "username": read_env("MONGODB_USERNAME"),
        "password": read_env("MONGODB_PASSWORD"),
        "auth_source": read_env("MONGODB_AUTH_SOURCE"),
    }
    Database.initialize(**connection_settings)
def load_logger():
logger.remove() # 移除默认配置
@@ -198,6 +211,7 @@ def raw_main():
init_config()
init_env()
load_env()
init_database() # 加载完成环境后初始化database
load_logger()
env_config = {key: os.getenv(key) for key in os.environ}
@@ -223,7 +237,6 @@ def raw_main():
if __name__ == "__main__":
try:
raw_main()

View File

@@ -1,5 +1,6 @@
from typing import Optional
from pymongo import MongoClient
from pymongo.database import Database as MongoDatabase
class Database:
_instance: Optional["Database"] = None
@@ -25,7 +26,7 @@ class Database:
else:
# 否则使用无认证连接
self.client = MongoClient(host, port)
self.db = self.client[db_name]
self.db: MongoDatabase = self.client[db_name]
@classmethod
def initialize(
@@ -37,15 +38,36 @@ class Database:
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> "Database":
) -> MongoDatabase:
if cls._instance is None:
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance
return cls._instance.db
@classmethod
def get_instance(cls) -> "Database":
def get_instance(cls) -> MongoDatabase:
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance
return cls._instance.db
#测试用
def get_random_group_messages(self, group_id: str, limit: int = 5):
    """Return a randomly chosen message of a group plus the messages after it.

    One document is sampled from ``self.db.messages`` among those whose
    ``group_id`` matches. The same collection is then queried for documents
    of that group whose ``time`` is greater than the sampled one, ordered by
    ascending ``time`` and capped at ``limit``.

    Args:
        group_id: Value matched against the ``group_id`` field.
        limit: Maximum number of follow-up messages to fetch.

    Returns:
        A list whose first element is the sampled message, followed by the
        follow-up messages. An empty list is returned when no message
        matches ``group_id`` (previously this raised ``IndexError``).
    """
    # Pick one random message of the group as the starting point.
    sampled = list(self.db.messages.aggregate([
        {"$match": {"group_id": group_id}},
        {"$sample": {"size": 1}},
    ]))
    if not sampled:
        # Nothing stored for this group: there is no starting point to
        # index into, so report "no messages" instead of failing.
        return []
    anchor = sampled[0]

    # Fetch the messages that come after the starting point.
    followers = self.db.messages.find({
        "group_id": group_id,
        "time": {"$gt": anchor["time"]},
    }).sort("time", 1).limit(limit)

    # Starting message first, then the ones that follow it.
    return [anchor, *followers]

View File

@@ -46,7 +46,7 @@ class ReasoningGUI:
# 初始化数据库连接
try:
self.db = Database.get_instance().db
self.db = Database.get_instance()
logger.success("数据库连接成功")
except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...")
@@ -60,7 +60,7 @@ class ReasoningGUI:
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance().db
self.db = Database.get_instance()
logger.success("数据库初始化成功")
except Exception:
logger.exception("数据库初始化失败")

View File

@@ -32,18 +32,6 @@ _message_manager_started = False
driver = get_driver()
config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")
# 初始化表情管理器
emoji_manager.initialize()

View File

@@ -111,11 +111,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +168,7 @@ class ChatManager:
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
data = self.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -204,7 +204,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
self.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +216,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
all_streams = self.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream

View File

@@ -76,16 +76,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('filename', 1)], unique=True)
if 'emoji' not in self.db.list_collection_names():
self.db.create_collection('emoji')
self.db.emoji.create_index([('embedding', '2dsphere')])
self.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': emoji_id},
{'$inc': {'usage_count': 1}}
)
@@ -119,7 +119,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -157,7 +157,7 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
self.db.db.emoji.update_one(
self.db.emoji.update_one(
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
@@ -239,7 +239,7 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
@@ -305,7 +305,7 @@ class EmojiManager:
}
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
self.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
@@ -346,7 +346,7 @@ class EmojiManager:
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(self.db.db.emoji.find())
all_emojis = list(self.db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -354,13 +354,13 @@ class EmojiManager:
try:
if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']})
self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
@@ -368,7 +368,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
result = self.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -379,7 +379,7 @@ class EmojiManager:
continue
# 验证清理结果
remaining_count = self.db.db.emoji.count_documents({})
remaining_count = self.db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -154,7 +154,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one(
self.db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,

View File

@@ -311,7 +311,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.db.knowledges.aggregate(pipeline))
results = list(self.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -168,7 +168,7 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
@@ -176,7 +176,7 @@ class RelationshipManager:
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.db.relationships.find({})
all_relationships = db.relationships.find({})
# 依次加载每条记录
for data in all_relationships:
await self.load_relationship(data)
@@ -206,7 +206,7 @@ class RelationshipManager:
saved = relationship.saved
db = Database.get_instance()
db.db.relationships.update_one(
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {
'platform': platform,

View File

@@ -23,7 +23,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.db.messages.insert_one(message_data)
self.db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -40,20 +40,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
if 'images' not in self.db.list_collection_names():
self.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
self.db.images.create_index([('hash', 1)], unique=True)
self.db.images.create_index([('url', 1)])
self.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
if 'image_descriptions' not in self.db.list_collection_names():
self.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -65,7 +65,7 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
result= self.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
@@ -79,7 +79,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.db.image_descriptions.update_one(
self.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
@@ -121,7 +121,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
existing = self.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
@@ -142,7 +142,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
self.db.images.insert_one(image_doc)
return file_path
@@ -159,7 +159,7 @@ class ImageManager:
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
existing = self.db.images.find_one({'url': url})
if existing:
return existing['path']
@@ -203,7 +203,7 @@ class ImageManager:
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
return self.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -226,7 +226,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
return self.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -269,7 +269,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -326,7 +326,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
self.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True

View File

@@ -96,7 +96,7 @@ class Memory_graph:
dot_data = {
"concept": node
}
self.db.db.store_memory_dots.insert_one(dot_data)
self.db.store_memory_dots.insert_one(dot_data)
@property
def dots(self):
@@ -106,7 +106,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@@ -115,7 +115,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同
chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length))
for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@@ -130,34 +130,34 @@ class Memory_graph:
def save_graph_to_db(self):
    """Replace the graph stored in the database with the contents of ``self.G``.

    Every node is written to ``graph_data.nodes`` as
    ``{"concept": <node key>, "memory_items": <list>}`` (an empty list when
    the node has no ``memory_items`` attribute), and every edge is written
    to ``graph_data.edges`` as ``{"source": ..., "target": ...}``.

    The node and edge collections are emptied first so that repeated saves
    do not accumulate duplicate documents.
    """
    graph_store = self.db.graph_data

    # Clear the existing graph data. Attribute access on a pymongo
    # collection names a separate, dotted collection ("graph_data.nodes",
    # "graph_data.edges"), so clearing "graph_data" alone leaves the
    # documents written below untouched; each target must be cleared.
    graph_store.delete_many({})
    graph_store.nodes.delete_many({})
    graph_store.edges.delete_many({})

    # Save nodes.
    for node in self.G.nodes(data=True):
        graph_store.nodes.insert_one({
            'concept': node[0],
            'memory_items': node[1].get('memory_items', []),  # defaults to an empty list
        })

    # Save edges.
    for edge in self.G.edges():
        graph_store.edges.insert_one({
            'source': edge[0],
            'target': edge[1],
        })
def load_graph_from_db(self):
# 清空当前图
self.G.clear()
# 加载节点
nodes = self.db.db.graph_data.nodes.find()
nodes = self.db.graph_data.nodes.find()
for node in nodes:
memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边
edges = self.db.db.graph_data.edges.find()
edges = self.db.graph_data.edges.find()
for edge in edges:
self.G.add_edge(edge['source'], edge['target'])

View File

@@ -892,15 +892,6 @@ config = driver.config
start_time = time.time()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体

View File

@@ -41,10 +41,10 @@ class LLM_request:
"""初始化数据库集合"""
try:
# 创建llm_usage集合的索引
self.db.db.llm_usage.create_index([("timestamp", 1)])
self.db.db.llm_usage.create_index([("model_name", 1)])
self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)])
self.db.llm_usage.create_index([("timestamp", 1)])
self.db.llm_usage.create_index([("model_name", 1)])
self.db.llm_usage.create_index([("user_id", 1)])
self.db.llm_usage.create_index([("request_type", 1)])
except Exception:
logger.error("创建数据库索引失败")
@@ -73,7 +73,7 @@ class LLM_request:
"status": "success",
"timestamp": datetime.now()
}
self.db.db.llm_usage.insert_one(usage_data)
self.db.llm_usage.insert_one(usage_data)
logger.info(
f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -14,16 +14,6 @@ from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator:
def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型
@@ -56,7 +46,7 @@ class ScheduleGenerator:
schedule_text = str
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
existing_schedule = self.db.schedule.find_one({"date": date_str})
if existing_schedule:
logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"]
@@ -73,7 +63,7 @@ class ScheduleGenerator:
try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e:
logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了"
@@ -153,7 +143,7 @@ class ScheduleGenerator:
"""打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text):
logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else:
logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items():

View File

@@ -53,7 +53,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float)
}
cursor = self.db.db.llm_usage.find({
cursor = self.db.llm_usage.find({
"timestamp": {"$gte": start_time}
})