Merge pull request #237 from Rikki-Zero/debug

refactor: 修复database单例多次初始化的问题,改变instance默认返回实例的类型,缩短db相关函数调用时的object名
This commit is contained in:
AL76
2025-03-12 02:48:00 +08:00
committed by GitHub
16 changed files with 105 additions and 101 deletions

15
bot.py
View File

@@ -12,6 +12,8 @@ from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from src.common.database import Database
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
@@ -96,6 +98,17 @@ def load_env():
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def init_database():
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
def load_logger(): def load_logger():
logger.remove() # 移除默认配置 logger.remove() # 移除默认配置
@@ -198,6 +211,7 @@ def raw_main():
init_config() init_config()
init_env() init_env()
load_env() load_env()
init_database() # 加载完成环境后初始化database
load_logger() load_logger()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}
@@ -223,7 +237,6 @@ def raw_main():
if __name__ == "__main__": if __name__ == "__main__":
try: try:
raw_main() raw_main()

View File

@@ -1,5 +1,6 @@
from typing import Optional from typing import Optional
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.database import Database as MongoDatabase
class Database: class Database:
_instance: Optional["Database"] = None _instance: Optional["Database"] = None
@@ -25,7 +26,7 @@ class Database:
else: else:
# 否则使用无认证连接 # 否则使用无认证连接
self.client = MongoClient(host, port) self.client = MongoClient(host, port)
self.db = self.client[db_name] self.db: MongoDatabase = self.client[db_name]
@classmethod @classmethod
def initialize( def initialize(
@@ -37,15 +38,36 @@ class Database:
password: Optional[str] = None, password: Optional[str] = None,
auth_source: Optional[str] = None, auth_source: Optional[str] = None,
uri: Optional[str] = None, uri: Optional[str] = None,
) -> "Database": ) -> MongoDatabase:
if cls._instance is None: if cls._instance is None:
cls._instance = cls( cls._instance = cls(
host, port, db_name, username, password, auth_source, uri host, port, db_name, username, password, auth_source, uri
) )
return cls._instance return cls._instance.db
@classmethod @classmethod
def get_instance(cls) -> "Database": def get_instance(cls) -> MongoDatabase:
if cls._instance is None: if cls._instance is None:
raise RuntimeError("Database not initialized") raise RuntimeError("Database not initialized")
return cls._instance return cls._instance.db
#测试用
def get_random_group_messages(self, group_id: str, limit: int = 5):
# 先随机获取一条消息
random_message = list(self.db.messages.aggregate([
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# 获取该消息之后的消息
subsequent_messages = list(self.db.messages.find({
"group_id": group_id,
"time": {"$gt": random_message["time"]}
}).sort("time", 1).limit(limit))
# 将随机消息和后续消息合并
messages = [random_message] + subsequent_messages
return messages

View File

@@ -46,7 +46,7 @@ class ReasoningGUI:
# 初始化数据库连接 # 初始化数据库连接
try: try:
self.db = Database.get_instance().db self.db = Database.get_instance()
logger.success("数据库连接成功") logger.success("数据库连接成功")
except RuntimeError: except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...") logger.warning("数据库未初始化,正在尝试初始化...")
@@ -60,7 +60,7 @@ class ReasoningGUI:
password=os.getenv("MONGODB_PASSWORD"), password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"), auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
) )
self.db = Database.get_instance().db self.db = Database.get_instance()
logger.success("数据库初始化成功") logger.success("数据库初始化成功")
except Exception: except Exception:
logger.exception("数据库初始化失败") logger.exception("数据库初始化失败")

View File

@@ -32,18 +32,6 @@ _message_manager_started = False
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
logger.success("初始化数据库成功")
# 初始化表情管理器 # 初始化表情管理器
emoji_manager.initialize() emoji_manager.initialize()

View File

@@ -111,11 +111,11 @@ class ChatManager:
def _ensure_collection(self): def _ensure_collection(self):
"""确保数据库集合存在并创建索引""" """确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names(): if "chat_streams" not in self.db.list_collection_names():
self.db.db.create_collection("chat_streams") self.db.create_collection("chat_streams")
# 创建索引 # 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True) self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index( self.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)] [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
) )
@@ -168,7 +168,7 @@ class ChatManager:
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id}) data = self.db.chat_streams.find_one({"stream_id": stream_id})
if data: if data:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
@@ -204,7 +204,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream): async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
self.db.db.chat_streams.update_one( self.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
) )
stream.saved = True stream.saved = True
@@ -216,7 +216,7 @@ class ChatManager:
async def load_all_streams(self): async def load_all_streams(self):
"""从数据库加载所有聊天流""" """从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({}) all_streams = self.db.chat_streams.find({})
for data in all_streams: for data in all_streams:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream

View File

@@ -76,16 +76,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
""" """
if 'emoji' not in self.db.db.list_collection_names(): if 'emoji' not in self.db.list_collection_names():
self.db.db.create_collection('emoji') self.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('filename', 1)], unique=True) self.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
self._ensure_db() self._ensure_db()
self.db.db.emoji.update_one( self.db.emoji.update_one(
{'_id': emoji_id}, {'_id': emoji_id},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
@@ -119,7 +119,7 @@ class EmojiManager:
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")
@@ -157,7 +157,7 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
self.db.db.emoji.update_one( self.db.emoji.update_one(
{'_id': selected_emoji['_id']}, {'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
@@ -239,7 +239,7 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) existing_emoji = self.db['emoji'].find_one({'filename': filename})
description = None description = None
if existing_emoji: if existing_emoji:
@@ -305,7 +305,7 @@ class EmojiManager:
} }
# 保存到emoji数据库 # 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record) self.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
@@ -346,7 +346,7 @@ class EmojiManager:
try: try:
self._ensure_db() self._ensure_db()
# 获取所有表情包记录 # 获取所有表情包记录
all_emojis = list(self.db.db.emoji.find()) all_emojis = list(self.db.emoji.find())
removed_count = 0 removed_count = 0
total_count = len(all_emojis) total_count = len(all_emojis)
@@ -354,13 +354,13 @@ class EmojiManager:
try: try:
if 'path' not in emoji: if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
if 'embedding' not in emoji: if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
@@ -368,7 +368,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']): if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}") logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录 # 从数据库中删除记录
result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) result = self.db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}") logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1 removed_count += 1
@@ -379,7 +379,7 @@ class EmojiManager:
continue continue
# 验证清理结果 # 验证清理结果
remaining_count = self.db.db.emoji.count_documents({}) remaining_count = self.db.emoji.count_documents({})
if removed_count > 0: if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录") logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -154,7 +154,7 @@ class ResponseGenerator:
reasoning_content: str, reasoning_content: str,
): ):
"""保存对话记录到数据库""" """保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one( self.db.reasoning_logs.insert_one(
{ {
"time": time.time(), "time": time.time(),
"chat_id": message.chat_stream.stream_id, "chat_id": message.chat_stream.stream_id,

View File

@@ -311,7 +311,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}}
] ]
results = list(self.db.db.knowledges.aggregate(pipeline)) results = list(self.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:

View File

@@ -168,7 +168,7 @@ class RelationshipManager:
async def load_all_relationships(self): async def load_all_relationships(self):
"""加载所有关系对象""" """加载所有关系对象"""
db = Database.get_instance() db = Database.get_instance()
all_relationships = db.db.relationships.find({}) all_relationships = db.relationships.find({})
for data in all_relationships: for data in all_relationships:
await self.load_relationship(data) await self.load_relationship(data)
@@ -176,7 +176,7 @@ class RelationshipManager:
"""每5分钟自动保存一次关系数据""" """每5分钟自动保存一次关系数据"""
db = Database.get_instance() db = Database.get_instance()
# 获取所有关系记录 # 获取所有关系记录
all_relationships = db.db.relationships.find({}) all_relationships = db.relationships.find({})
# 依次加载每条记录 # 依次加载每条记录
for data in all_relationships: for data in all_relationships:
await self.load_relationship(data) await self.load_relationship(data)
@@ -206,7 +206,7 @@ class RelationshipManager:
saved = relationship.saved saved = relationship.saved
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.relationships.update_one(
{'user_id': user_id, 'platform': platform}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {
'platform': platform, 'platform': platform,

View File

@@ -23,7 +23,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text, "detailed_plain_text": message.detailed_plain_text,
"topic": topic, "topic": topic,
} }
self.db.db.messages.insert_one(message_data) self.db.messages.insert_one(message_data)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -40,20 +40,20 @@ class ImageManager:
def _ensure_image_collection(self): def _ensure_image_collection(self):
"""确保images集合存在并创建索引""" """确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names(): if 'images' not in self.db.list_collection_names():
self.db.db.create_collection('images') self.db.create_collection('images')
# 创建索引 # 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True) self.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)]) self.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)]) self.db.images.create_index([('path', 1)])
def _ensure_description_collection(self): def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引""" """确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names(): if 'image_descriptions' not in self.db.list_collection_names():
self.db.db.create_collection('image_descriptions') self.db.create_collection('image_descriptions')
# 创建索引 # 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True) self.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)]) self.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -65,7 +65,7 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result= self.db.db.image_descriptions.find_one({ result= self.db.image_descriptions.find_one({
'hash': image_hash, 'hash': image_hash,
'type': description_type 'type': description_type
}) })
@@ -79,7 +79,7 @@ class ImageManager:
description: 描述文本 description: 描述文本
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
self.db.db.image_descriptions.update_one( self.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type}, {'hash': image_hash, 'type': description_type},
{ {
'$set': { '$set': {
@@ -121,7 +121,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重 # 查重
existing = self.db.db.images.find_one({'hash': image_hash}) existing = self.db.images.find_one({'hash': image_hash})
if existing: if existing:
return existing['path'] return existing['path']
@@ -142,7 +142,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.db.images.insert_one(image_doc) self.db.images.insert_one(image_doc)
return file_path return file_path
@@ -159,7 +159,7 @@ class ImageManager:
""" """
try: try:
# 先查找是否已存在 # 先查找是否已存在
existing = self.db.db.images.find_one({'url': url}) existing = self.db.images.find_one({'url': url})
if existing: if existing:
return existing['path'] return existing['path']
@@ -203,7 +203,7 @@ class ImageManager:
Returns: Returns:
bool: 是否存在 bool: 是否存在
""" """
return self.db.db.images.find_one({'url': url}) is not None return self.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在 """检查图像是否已存在
@@ -226,7 +226,7 @@ class ImageManager:
return False return False
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None return self.db.images.find_one({'hash': image_hash}) is not None
except Exception as e: except Exception as e:
logger.error(f"检查哈希失败: {str(e)}") logger.error(f"检查哈希失败: {str(e)}")
@@ -269,7 +269,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.db.images.update_one( self.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -326,7 +326,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.db.images.update_one( self.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True

View File

@@ -96,7 +96,7 @@ class Memory_graph:
dot_data = { dot_data = {
"concept": node "concept": node
} }
self.db.db.store_memory_dots.insert_one(dot_data) self.db.store_memory_dots.insert_one(dot_data)
@property @property
def dots(self): def dots(self):
@@ -106,7 +106,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str): def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录 # 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = '' chat_text = ''
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
logger.info( logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@@ -115,7 +115,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # 获取groupid group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_record = list( chat_record = list(
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length)) length))
for record in chat_record: for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@@ -130,34 +130,34 @@ class Memory_graph:
def save_graph_to_db(self): def save_graph_to_db(self):
# 清空现有的图数据 # 清空现有的图数据
self.db.db.graph_data.delete_many({}) self.db.graph_data.delete_many({})
# 保存节点 # 保存节点
for node in self.G.nodes(data=True): for node in self.G.nodes(data=True):
node_data = { node_data = {
'concept': node[0], 'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表 'memory_items': node[1].get('memory_items', []) # 默认为空列表
} }
self.db.db.graph_data.nodes.insert_one(node_data) self.db.graph_data.nodes.insert_one(node_data)
# 保存边 # 保存边
for edge in self.G.edges(): for edge in self.G.edges():
edge_data = { edge_data = {
'source': edge[0], 'source': edge[0],
'target': edge[1] 'target': edge[1]
} }
self.db.db.graph_data.edges.insert_one(edge_data) self.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self): def load_graph_from_db(self):
# 清空当前图 # 清空当前图
self.G.clear() self.G.clear()
# 加载节点 # 加载节点
nodes = self.db.db.graph_data.nodes.find() nodes = self.db.graph_data.nodes.find()
for node in nodes: for node in nodes:
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items) self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边 # 加载边
edges = self.db.db.graph_data.edges.find() edges = self.db.graph_data.edges.find()
for edge in edges: for edge in edges:
self.G.add_edge(edge['source'], edge['target']) self.G.add_edge(edge['source'], edge['target'])

View File

@@ -892,15 +892,6 @@ config = driver.config
start_time = time.time() start_time = time.time()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图 # 创建记忆图
memory_graph = Memory_graph() memory_graph = Memory_graph()
# 创建海马体 # 创建海马体

View File

@@ -41,10 +41,10 @@ class LLM_request:
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 创建llm_usage集合的索引
self.db.db.llm_usage.create_index([("timestamp", 1)]) self.db.llm_usage.create_index([("timestamp", 1)])
self.db.db.llm_usage.create_index([("model_name", 1)]) self.db.llm_usage.create_index([("model_name", 1)])
self.db.db.llm_usage.create_index([("user_id", 1)]) self.db.llm_usage.create_index([("user_id", 1)])
self.db.db.llm_usage.create_index([("request_type", 1)]) self.db.llm_usage.create_index([("request_type", 1)])
except Exception: except Exception:
logger.error("创建数据库索引失败") logger.error("创建数据库索引失败")
@@ -73,7 +73,7 @@ class LLM_request:
"status": "success", "status": "success",
"timestamp": datetime.now() "timestamp": datetime.now()
} }
self.db.db.llm_usage.insert_one(usage_data) self.db.llm_usage.insert_one(usage_data)
logger.info( logger.info(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -14,16 +14,6 @@ from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator: class ScheduleGenerator:
def __init__(self): def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型 # 根据global_config.llm_normal这一字典配置指定模型
@@ -56,7 +46,7 @@ class ScheduleGenerator:
schedule_text = str schedule_text = str
existing_schedule = self.db.db.schedule.find_one({"date": date_str}) existing_schedule = self.db.schedule.find_one({"date": date_str})
if existing_schedule: if existing_schedule:
logger.debug(f"{date_str}的日程已存在:") logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"] schedule_text = existing_schedule["schedule"]
@@ -73,7 +63,7 @@ class ScheduleGenerator:
try: try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt) schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e: except Exception as e:
logger.error(f"生成日程失败: {str(e)}") logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了" schedule_text = "生成日程时出错了"
@@ -153,7 +143,7 @@ class ScheduleGenerator:
"""打印完整的日程安排""" """打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text): if not self._parse_schedule(self.today_schedule_text):
logger.warning("今日日程有误,将在下次运行时重新生成") logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
logger.info("=== 今日日程安排 ===") logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items(): for time_str, activity in self.today_schedule.items():

View File

@@ -53,7 +53,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float) "costs_by_model": defaultdict(float)
} }
cursor = self.db.db.llm_usage.find({ cursor = self.db.llm_usage.find({
"timestamp": {"$gte": start_time} "timestamp": {"$gte": start_time}
}) })