diff --git a/bot.py b/bot.py index 36d621a6e..8d51cee3c 100644 --- a/bot.py +++ b/bot.py @@ -12,6 +12,8 @@ from loguru import logger from nonebot.adapters.onebot.v11 import Adapter import platform +from src.common.database import Database + # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} @@ -96,6 +98,17 @@ def load_env(): logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") +def init_database(): + Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), + ) + def load_logger(): logger.remove() # 移除默认配置 @@ -198,6 +211,7 @@ def raw_main(): init_config() init_env() load_env() + init_database() # 加载完成环境后初始化database load_logger() env_config = {key: os.getenv(key) for key in os.environ} @@ -223,7 +237,6 @@ def raw_main(): if __name__ == "__main__": - try: raw_main() diff --git a/src/common/database.py b/src/common/database.py index f0954b07c..9d9a596d1 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,7 +1,7 @@ from typing import Optional from pymongo import MongoClient - +from pymongo.database import Database as MongoDatabase class Database: _instance: Optional["Database"] = None @@ -27,7 +27,7 @@ class Database: else: # 否则使用无认证连接 self.client = MongoClient(host, port) - self.db = self.client[db_name] + self.db: MongoDatabase = self.client[db_name] @classmethod def initialize( @@ -39,18 +39,18 @@ class Database: password: Optional[str] = None, auth_source: Optional[str] = None, uri: Optional[str] = None, - ) -> "Database": + ) -> MongoDatabase: if cls._instance is None: cls._instance = cls( host, port, db_name, username, password, auth_source, uri ) - return cls._instance + return cls._instance.db @classmethod - def get_instance(cls) -> "Database": + def get_instance(cls) -> MongoDatabase: if cls._instance is None: raise RuntimeError("Database not initialized") - return cls._instance + return cls._instance.db #测试用 diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index e131658b8..84b95adaf 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -46,7 +46,7 @@ class ReasoningGUI: # 初始化数据库连接 try: - self.db = Database.get_instance().db + self.db = Database.get_instance() logger.success("数据库连接成功") except RuntimeError: logger.warning("数据库未初始化,正在尝试初始化...") @@ -60,7 +60,7 @@ class ReasoningGUI: password=os.getenv("MONGODB_PASSWORD"), auth_source=os.getenv("MONGODB_AUTH_SOURCE"), ) - self.db = Database.get_instance().db + self.db = Database.get_instance() logger.success("数据库初始化成功") except Exception: logger.exception("数据库初始化失败") diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index ec3d4f01d..8ae525708 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -32,18 +32,6 @@ _message_manager_started = False driver = get_driver() config = driver.config -Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), -) -logger.success("初始化数据库成功") - - # 初始化表情管理器 emoji_manager.initialize() diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index bee679173..3ccd03f81 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -111,11 +111,11 @@ class ChatManager: def _ensure_collection(self): """确保数据库集合存在并创建索引""" - if "chat_streams" not in self.db.db.list_collection_names(): - self.db.db.create_collection("chat_streams") + if "chat_streams" not in self.db.list_collection_names(): + self.db.create_collection("chat_streams") # 创建索引 - self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True) - self.db.db.chat_streams.create_index( + self.db.chat_streams.create_index([("stream_id", 1)], unique=True) + self.db.chat_streams.create_index( [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)] ) @@ -168,7 +168,7 @@ class ChatManager: return stream # 检查数据库中是否存在 - data = self.db.db.chat_streams.find_one({"stream_id": stream_id}) + data = self.db.chat_streams.find_one({"stream_id": stream_id}) if data: stream = ChatStream.from_dict(data) # 更新用户信息和群组信息 @@ -204,7 +204,7 @@ class ChatManager: async def _save_stream(self, stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - self.db.db.chat_streams.update_one( + self.db.chat_streams.update_one( {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True ) stream.saved = True @@ -216,7 +216,7 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = self.db.db.chat_streams.find({}) + all_streams = self.db.chat_streams.find({}) for data in all_streams: stream = ChatStream.from_dict(data) self.streams[stream.stream_id] = stream diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 3adb952d3..1743571e9 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -76,16 +76,16 @@ class EmojiManager: 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 """ - if 'emoji' not in self.db.db.list_collection_names(): - self.db.db.create_collection('emoji') - self.db.db.emoji.create_index([('embedding', '2dsphere')]) - self.db.db.emoji.create_index([('filename', 1)], unique=True) + if 'emoji' not in self.db.list_collection_names(): + self.db.create_collection('emoji') + self.db.emoji.create_index([('embedding', '2dsphere')]) + self.db.emoji.create_index([('filename', 1)], unique=True) def record_usage(self, emoji_id: str): """记录表情使用次数""" try: self._ensure_db() - self.db.db.emoji.update_one( + self.db.emoji.update_one( {'_id': emoji_id}, {'$inc': {'usage_count': 1}} ) @@ -119,7 +119,7 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) + all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) if not all_emojis: logger.warning("数据库中没有任何表情包") @@ -157,7 +157,7 @@ class EmojiManager: if selected_emoji and 'path' in selected_emoji: # 更新使用次数 - self.db.db.emoji.update_one( + self.db.emoji.update_one( {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) @@ -236,7 +236,7 @@ class EmojiManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 检查是否已经注册过 - existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) + existing_emoji = self.db['emoji'].find_one({'filename': filename}) description = None if existing_emoji: @@ -298,7 +298,7 @@ class EmojiManager: } # 保存到emoji数据库 - self.db.db['emoji'].insert_one(emoji_record) + self.db['emoji'].insert_one(emoji_record) logger.success(f"注册新表情包: {filename}") logger.info(f"描述: {description}") @@ -338,7 +338,7 @@ class EmojiManager: try: self._ensure_db() # 获取所有表情包记录 - all_emojis = list(self.db.db.emoji.find()) + all_emojis = list(self.db.emoji.find()) removed_count = 0 total_count = len(all_emojis) @@ -346,13 +346,13 @@ class EmojiManager: try: if 'path' not in emoji: logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") - self.db.db.emoji.delete_one({'_id': emoji['_id']}) + self.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue if 'embedding' not in emoji: logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") - self.db.db.emoji.delete_one({'_id': emoji['_id']}) + self.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue @@ -360,7 +360,7 @@ class EmojiManager: if not os.path.exists(emoji['path']): logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 - result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) + result = self.db.emoji.delete_one({'_id': emoji['_id']}) if result.deleted_count > 0: logger.debug(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 @@ -371,7 +371,7 @@ class EmojiManager: continue # 验证清理结果 - remaining_count = self.db.db.emoji.count_documents({}) + remaining_count = self.db.emoji.count_documents({}) if removed_count > 0: logger.success(f"已清理 {removed_count} 个失效的表情包记录") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index af7334afe..285ea59b7 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -154,7 +154,7 @@ class ResponseGenerator: reasoning_content: str, ): """保存对话记录到数据库""" - self.db.db.reasoning_logs.insert_one( + self.db.reasoning_logs.insert_one( { "time": time.time(), "chat_id": message.chat_stream.stream_id, diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index fec6c7926..ea3da777a 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -311,7 +311,7 @@ class PromptBuilder: {"$project": {"content": 1, "similarity": 1}} ] - results = list(self.db.db.knowledges.aggregate(pipeline)) + results = list(self.db.knowledges.aggregate(pipeline)) # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") if not results: diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index 9e7cafda0..baebb1fe8 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -169,7 +169,7 @@ class RelationshipManager: async def load_all_relationships(self): """加载所有关系对象""" db = Database.get_instance() - all_relationships = db.db.relationships.find({}) + all_relationships = db.relationships.find({}) for data in all_relationships: await self.load_relationship(data) @@ -177,7 +177,7 @@ class RelationshipManager: """每5分钟自动保存一次关系数据""" db = Database.get_instance() # 获取所有关系记录 - all_relationships = db.db.relationships.find({}) + all_relationships = db.relationships.find({}) # 依次加载每条记录 for data in all_relationships: await self.load_relationship(data) @@ -207,7 +207,7 @@ class RelationshipManager: saved = relationship.saved db = Database.get_instance() - db.db.relationships.update_one( + db.relationships.update_one( {'user_id': user_id, 'platform': platform}, {'$set': { 'platform': platform, diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index f403b2c8b..dc03e4ced 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -25,7 +25,7 @@ class MessageStorage: "detailed_plain_text": message.detailed_plain_text, "topic": topic, } - self.db.db.messages.insert_one(message_data) + self.db.messages.insert_one(message_data) except Exception: logger.exception("存储消息失败") diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 25f23359b..ddb3b04cd 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -44,20 +44,20 @@ class ImageManager: def _ensure_image_collection(self): """确保images集合存在并创建索引""" - if 'images' not in self.db.db.list_collection_names(): - self.db.db.create_collection('images') + if 'images' not in self.db.list_collection_names(): + self.db.create_collection('images') # 创建索引 - self.db.db.images.create_index([('hash', 1)], unique=True) - self.db.db.images.create_index([('url', 1)]) - self.db.db.images.create_index([('path', 1)]) + self.db.images.create_index([('hash', 1)], unique=True) + self.db.images.create_index([('url', 1)]) + self.db.images.create_index([('path', 1)]) def _ensure_description_collection(self): """确保image_descriptions集合存在并创建索引""" - if 'image_descriptions' not in self.db.db.list_collection_names(): - self.db.db.create_collection('image_descriptions') + if 'image_descriptions' not in self.db.list_collection_names(): + self.db.create_collection('image_descriptions') # 创建索引 - self.db.db.image_descriptions.create_index([('hash', 1)], unique=True) - self.db.db.image_descriptions.create_index([('type', 1)]) + self.db.image_descriptions.create_index([('hash', 1)], unique=True) + self.db.image_descriptions.create_index([('type', 1)]) def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -69,7 +69,7 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result= self.db.db.image_descriptions.find_one({ + result= self.db.image_descriptions.find_one({ 'hash': image_hash, 'type': description_type }) @@ -83,7 +83,7 @@ class ImageManager: description: 描述文本 description_type: 描述类型 ('emoji' 或 'image') """ - self.db.db.image_descriptions.update_one( + self.db.image_descriptions.update_one( {'hash': image_hash, 'type': description_type}, { '$set': { @@ -125,7 +125,7 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 查重 - existing = self.db.db.images.find_one({'hash': image_hash}) + existing = self.db.images.find_one({'hash': image_hash}) if existing: return existing['path'] @@ -146,7 +146,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.db.images.insert_one(image_doc) + self.db.images.insert_one(image_doc) return file_path @@ -163,7 +163,7 @@ class ImageManager: """ try: # 先查找是否已存在 - existing = self.db.db.images.find_one({'url': url}) + existing = self.db.images.find_one({'url': url}) if existing: return existing['path'] @@ -207,7 +207,7 @@ class ImageManager: Returns: bool: 是否存在 """ - return self.db.db.images.find_one({'url': url}) is not None + return self.db.images.find_one({'url': url}) is not None def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: """检查图像是否已存在 @@ -230,7 +230,7 @@ class ImageManager: return False image_hash = hashlib.md5(image_bytes).hexdigest() - return self.db.db.images.find_one({'hash': image_hash}) is not None + return self.db.images.find_one({'hash': image_hash}) is not None except Exception as e: logger.error(f"检查哈希失败: {str(e)}") @@ -273,7 +273,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.db.images.update_one( + self.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -330,7 +330,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.db.images.update_one( + self.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py index e9d7167fd..1bebf0930 100644 --- a/src/plugins/knowledege/knowledge_library.py +++ b/src/plugins/knowledege/knowledge_library.py @@ -17,17 +17,6 @@ load_dotenv(env_path) from src.common.database import Database -# 从环境变量获取配置 -Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), -) - class KnowledgeLibrary: def __init__(self): self.db = Database.get_instance() @@ -72,7 +61,7 @@ class KnowledgeLibrary: """处理单个文件""" try: # 检查文件是否已处理 - if self.db.db.processed_files.find_one({"file_path": file_path}): + if self.db.processed_files.find_one({"file_path": file_path}): print(f"文件已处理过,跳过: {file_path}") return @@ -104,14 +93,14 @@ class KnowledgeLibrary: content_hash = hash(segment) # 更新或插入文档 - self.db.db.knowledges.update_one( + self.db.knowledges.update_one( {"content_hash": content_hash}, {"$set": doc}, upsert=True ) # 记录文件已处理 - self.db.db.processed_files.insert_one({ + self.db.processed_files.insert_one({ "file_path": file_path, "processed_time": time.time() }) @@ -178,7 +167,7 @@ class KnowledgeLibrary: {"$project": {"content": 1, "similarity": 1, "file_path": 1}} ] - results = list(self.db.db.knowledges.aggregate(pipeline)) + results = list(self.db.knowledges.aggregate(pipeline)) return results # 创建单例实例 diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index 9f15164f1..d6ba8f3b2 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -96,7 +96,7 @@ class Memory_graph: dot_data = { "concept": node } - self.db.db.store_memory_dots.insert_one(dot_data) + self.db.store_memory_dots.insert_one(dot_data) @property def dots(self): @@ -106,7 +106,7 @@ class Memory_graph: def get_random_chat_from_db(self, length: int, timestamp: str): # 从数据库中根据时间戳获取离其最近的聊天记录 chat_text = '' - closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 logger.info( f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") @@ -115,7 +115,7 @@ class Memory_graph: group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_record = list( - self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( + self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( length)) for record in chat_record: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) @@ -130,34 +130,34 @@ class Memory_graph: def save_graph_to_db(self): # 清空现有的图数据 - self.db.db.graph_data.delete_many({}) + self.db.graph_data.delete_many({}) # 保存节点 for node in self.G.nodes(data=True): node_data = { 'concept': node[0], 'memory_items': node[1].get('memory_items', []) # 默认为空列表 } - self.db.db.graph_data.nodes.insert_one(node_data) + self.db.graph_data.nodes.insert_one(node_data) # 保存边 for edge in self.G.edges(): edge_data = { 'source': edge[0], 'target': edge[1] } - self.db.db.graph_data.edges.insert_one(edge_data) + self.db.graph_data.edges.insert_one(edge_data) def load_graph_from_db(self): # 清空当前图 self.G.clear() # 加载节点 - nodes = self.db.db.graph_data.nodes.find() + nodes = self.db.graph_data.nodes.find() for node in nodes: memory_items = node.get('memory_items', []) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] self.G.add_node(node['concept'], memory_items=memory_items) # 加载边 - edges = self.db.db.graph_data.edges.find() + edges = self.db.graph_data.edges.find() for edge in edges: self.G.add_edge(edge['source'], edge['target']) diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index c0b551b58..3c844c3ff 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -887,15 +887,6 @@ config = driver.config start_time = time.time() -Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), -) # 创建记忆图 memory_graph = Memory_graph() # 创建海马体 diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 3424d662c..75b46f611 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -41,10 +41,10 @@ class LLM_request: """初始化数据库集合""" try: # 创建llm_usage集合的索引 - self.db.db.llm_usage.create_index([("timestamp", 1)]) - self.db.db.llm_usage.create_index([("model_name", 1)]) - self.db.db.llm_usage.create_index([("user_id", 1)]) - self.db.db.llm_usage.create_index([("request_type", 1)]) + self.db.llm_usage.create_index([("timestamp", 1)]) + self.db.llm_usage.create_index([("model_name", 1)]) + self.db.llm_usage.create_index([("user_id", 1)]) + self.db.llm_usage.create_index([("request_type", 1)]) except Exception: logger.error("创建数据库索引失败") @@ -73,7 +73,7 @@ class LLM_request: "status": "success", "timestamp": datetime.now() } - self.db.db.llm_usage.insert_one(usage_data) + self.db.llm_usage.insert_one(usage_data) logger.info( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index 12c6ce3b5..bde593890 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -14,16 +14,6 @@ from ..models.utils_model import LLM_request driver = get_driver() config = driver.config -Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), -) - class ScheduleGenerator: def __init__(self): # 根据global_config.llm_normal这一字典配置指定模型 @@ -56,7 +46,7 @@ class ScheduleGenerator: schedule_text = str - existing_schedule = self.db.db.schedule.find_one({"date": date_str}) + existing_schedule = self.db.schedule.find_one({"date": date_str}) if existing_schedule: logger.debug(f"{date_str}的日程已存在:") schedule_text = existing_schedule["schedule"] @@ -73,7 +63,7 @@ class ScheduleGenerator: try: schedule_text, _ = await self.llm_scheduler.generate_response(prompt) - self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) + self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) except Exception as e: logger.error(f"生成日程失败: {str(e)}") schedule_text = "生成日程时出错了" @@ -153,7 +143,7 @@ class ScheduleGenerator: """打印完整的日程安排""" if not self._parse_schedule(self.today_schedule_text): logger.warning("今日日程有误,将在下次运行时重新生成") - self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) + self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) else: logger.info("=== 今日日程安排 ===") for time_str, activity in self.today_schedule.items(): diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py index 2974389e6..4629f0e0b 100644 --- a/src/plugins/utils/statistic.py +++ b/src/plugins/utils/statistic.py @@ -53,7 +53,7 @@ class LLMStatistics: "costs_by_model": defaultdict(float) } - cursor = self.db.db.llm_usage.find({ + cursor = self.db.llm_usage.find({ "timestamp": {"$gte": start_time} })