diff --git a/bot.py b/bot.py index 19ad80025..48517fe24 100644 --- a/bot.py +++ b/bot.py @@ -12,8 +12,6 @@ from loguru import logger from nonebot.adapters.onebot.v11 import Adapter import platform -from src.common.database import Database - # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} @@ -111,18 +109,6 @@ def load_env(): logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") -def init_database(): - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - - def load_logger(): logger.remove() # 移除默认配置 if os.getenv("ENVIRONMENT") == "dev": @@ -223,7 +209,6 @@ def raw_main(): init_config() init_env() load_env() - init_database() # 加载完成环境后初始化database load_logger() env_config = {key: os.getenv(key) for key in os.environ} diff --git a/src/common/database.py b/src/common/database.py index c6cead225..ca73dc468 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,73 +1,53 @@ -from typing import Optional +import os +from typing import cast from pymongo import MongoClient -from pymongo.database import Database as MongoDatabase +from pymongo.database import Database -class Database: - _instance: Optional["Database"] = None - - def __init__( - self, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ): - if uri and uri.startswith("mongodb://"): - # 优先使用URI连接 - self.client = MongoClient(uri) - elif username and password: - # 如果有用户名和密码,使用认证连接 - self.client = MongoClient( - host, port, username=username, password=password, authSource=auth_source - ) - else: - # 否则使用无认证连接 - self.client = MongoClient(host, port) - self.db: MongoDatabase = self.client[db_name] - - @classmethod - def initialize( - cls, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ) -> MongoDatabase: - if cls._instance is None: - cls._instance = cls( - host, port, db_name, username, password, auth_source, uri - ) - return cls._instance.db - - @classmethod - def get_instance(cls) -> MongoDatabase: - if cls._instance is None: - raise RuntimeError("Database not initialized") - return cls._instance.db +_client = None +_db = None - #测试用 - - def get_random_group_messages(self, group_id: str, limit: int = 5): - # 先随机获取一条消息 - random_message = list(self.db.messages.aggregate([ - {"$match": {"group_id": group_id}}, - {"$sample": {"size": 1}} - ]))[0] - - # 获取该消息之后的消息 - subsequent_messages = list(self.db.messages.find({ - "group_id": group_id, - "time": {"$gt": random_message["time"]} - }).sort("time", 1).limit(limit)) - - # 将随机消息和后续消息合并 - messages = [random_message] + subsequent_messages - - return messages \ No newline at end of file +def __create_database_instance(): + uri = os.getenv("MONGODB_URI") + host = os.getenv("MONGODB_HOST", "127.0.0.1") + port = int(os.getenv("MONGODB_PORT", "27017")) + db_name = os.getenv("DATABASE_NAME", "MegBot") + username = os.getenv("MONGODB_USERNAME") + password = os.getenv("MONGODB_PASSWORD") + auth_source = os.getenv("MONGODB_AUTH_SOURCE") + + if uri and uri.startswith("mongodb://"): + # 优先使用URI连接 + return MongoClient(uri) + + if username and password: + # 如果有用户名和密码,使用认证连接 + return MongoClient( + host, port, username=username, password=password, authSource=auth_source + ) + + # 否则使用无认证连接 + return MongoClient(host, port) + + +def get_db(): + """获取数据库连接实例,延迟初始化。""" + global _client, _db + if _client is None: + _client = __create_database_instance() + _db = _client[os.getenv("DATABASE_NAME", "MegBot")] + return _db + + +class DBWrapper: + """数据库代理类,保持接口兼容性同时实现懒加载。""" + + def __getattr__(self, name): + return getattr(get_db(), name) + + def __getitem__(self, key): + return get_db()[key] + + +# 全局数据库访问点 +db: Database = DBWrapper() diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index 84b95adaf..e79f8f91f 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import Dict, List from loguru import logger from typing import Optional -from ..common.database import Database +from ..common.database import db import customtkinter as ctk from dotenv import load_dotenv @@ -44,28 +44,6 @@ class ReasoningGUI: self.root.geometry('800x600') self.root.protocol("WM_DELETE_WINDOW", self._on_closing) - # 初始化数据库连接 - try: - self.db = Database.get_instance() - logger.success("数据库连接成功") - except RuntimeError: - logger.warning("数据库未初始化,正在尝试初始化...") - try: - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - self.db = Database.get_instance() - logger.success("数据库初始化成功") - except Exception: - logger.exception("数据库初始化失败") - sys.exit(1) - # 存储群组数据 self.group_data: Dict[str, List[dict]] = {} @@ -264,11 +242,11 @@ class ReasoningGUI: logger.debug(f"查询条件: {query}") # 先获取一条记录检查时间格式 - sample = self.db.reasoning_logs.find_one() + sample = db.reasoning_logs.find_one() if sample: logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}") - cursor = self.db.reasoning_logs.find(query).sort("time", -1) + cursor = db.reasoning_logs.find(query).sort("time", -1) new_data = {} total_count = 0 @@ -333,17 +311,6 @@ class ReasoningGUI: def main(): - """主函数""" - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - app = ReasoningGUI() app.run() diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 1c6bf3f35..d7a7bd7e4 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -7,7 +7,6 @@ from nonebot import get_driver, on_message, require from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent from nonebot.typing import T_State -from ...common.database import Database from ..moods.moods import MoodManager # 导入情绪管理器 from ..schedule.schedule_generator import bot_schedule from ..utils.statistic import LLMStatistics diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index 3ccd03f81..60b0af493 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from loguru import logger -from ...common.database import Database +from ...common.database import db from .message_base import GroupInfo, UserInfo @@ -83,7 +83,6 @@ class ChatManager: def __init__(self): if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream - self.db = Database.get_instance() self._ensure_collection() self._initialized = True # 在事件循环中启动初始化 @@ -111,11 +110,11 @@ class ChatManager: def _ensure_collection(self): """确保数据库集合存在并创建索引""" - if "chat_streams" not in self.db.list_collection_names(): - self.db.create_collection("chat_streams") + if "chat_streams" not in db.list_collection_names(): + db.create_collection("chat_streams") # 创建索引 - self.db.chat_streams.create_index([("stream_id", 1)], unique=True) - self.db.chat_streams.create_index( + db.chat_streams.create_index([("stream_id", 1)], unique=True) + db.chat_streams.create_index( [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)] ) @@ -168,7 +167,7 @@ class ChatManager: return stream # 检查数据库中是否存在 - data = self.db.chat_streams.find_one({"stream_id": stream_id}) + data = db.chat_streams.find_one({"stream_id": stream_id}) if data: stream = ChatStream.from_dict(data) # 更新用户信息和群组信息 @@ -204,7 +203,7 @@ class ChatManager: async def _save_stream(self, stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - self.db.chat_streams.update_one( + db.chat_streams.update_one( {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True ) stream.saved = True @@ -216,7 +215,7 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = self.db.chat_streams.find({}) + all_streams = db.chat_streams.find({}) for data in all_streams: stream = ChatStream.from_dict(data) self.streams[stream.stream_id] = stream diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 1c8a07699..822eda009 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -12,7 +12,7 @@ import io from loguru import logger from nonebot import get_driver -from ...common.database import Database +from ...common.database import db from ..chat.config import global_config from ..chat.utils import get_embedding from ..chat.utils_image import ImageManager, image_path_to_base64 @@ -30,12 +30,10 @@ class EmojiManager: def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance.db = None cls._instance._initialized = False return cls._instance def __init__(self): - self.db = Database.get_instance() self._scan_task = None self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60, @@ -50,7 +48,6 @@ class EmojiManager: """初始化数据库连接和表情目录""" if not self._initialized: try: - self.db = Database.get_instance() self._ensure_emoji_collection() self._ensure_emoji_dir() self._initialized = True @@ -78,16 +75,16 @@ class EmojiManager: 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 """ - if 'emoji' not in self.db.list_collection_names(): - self.db.create_collection('emoji') - self.db.emoji.create_index([('embedding', '2dsphere')]) - self.db.emoji.create_index([('filename', 1)], unique=True) + if 'emoji' not in db.list_collection_names(): + db.create_collection('emoji') + db.emoji.create_index([('embedding', '2dsphere')]) + db.emoji.create_index([('filename', 1)], unique=True) def record_usage(self, emoji_id: str): """记录表情使用次数""" try: self._ensure_db() - self.db.emoji.update_one( + db.emoji.update_one( {'_id': emoji_id}, {'$inc': {'usage_count': 1}} ) @@ -121,7 +118,7 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) + all_emojis = list(db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) if not all_emojis: logger.warning("数据库中没有任何表情包") @@ -159,7 +156,7 @@ class EmojiManager: if selected_emoji and 'path' in selected_emoji: # 更新使用次数 - self.db.emoji.update_one( + db.emoji.update_one( {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) @@ -241,14 +238,14 @@ class EmojiManager: image_hash = hashlib.md5(image_bytes).hexdigest() image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 检查是否已经注册过 - existing_emoji = self.db['emoji'].find_one({'filename': filename}) + existing_emoji = db['emoji'].find_one({'filename': filename}) description = None if existing_emoji: # 即使表情包已存在,也检查是否需要同步到images集合 description = existing_emoji.get('discription') # 检查是否在images集合中存在 - existing_image = image_manager.db.images.find_one({'hash': image_hash}) + existing_image = db.images.find_one({'hash': image_hash}) if not existing_image: # 同步到images集合 image_doc = { @@ -258,7 +255,7 @@ class EmojiManager: 'description': description, 'timestamp': int(time.time()) } - image_manager.db.images.update_one( + db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -307,7 +304,7 @@ class EmojiManager: } # 保存到emoji数据库 - self.db['emoji'].insert_one(emoji_record) + db['emoji'].insert_one(emoji_record) logger.success(f"注册新表情包: {filename}") logger.info(f"描述: {description}") @@ -320,7 +317,7 @@ class EmojiManager: 'description': description, 'timestamp': int(time.time()) } - image_manager.db.images.update_one( + db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -348,7 +345,7 @@ class EmojiManager: try: self._ensure_db() # 获取所有表情包记录 - all_emojis = list(self.db.emoji.find()) + all_emojis = list(db.emoji.find()) removed_count = 0 total_count = len(all_emojis) @@ -356,13 +353,13 @@ class EmojiManager: try: if 'path' not in emoji: logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") - self.db.emoji.delete_one({'_id': emoji['_id']}) + db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue if 'embedding' not in emoji: logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") - self.db.emoji.delete_one({'_id': emoji['_id']}) + db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue @@ -370,7 +367,7 @@ class EmojiManager: if not os.path.exists(emoji['path']): logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 - result = self.db.emoji.delete_one({'_id': emoji['_id']}) + result = db.emoji.delete_one({'_id': emoji['_id']}) if result.deleted_count > 0: logger.debug(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 @@ -381,7 +378,7 @@ class EmojiManager: continue # 验证清理结果 - remaining_count = self.db.emoji.count_documents({}) + remaining_count = db.emoji.count_documents({}) if removed_count > 0: logger.success(f"已清理 {removed_count} 个失效的表情包记录") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 84e1937b0..2e0c0eb1f 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union from nonebot import get_driver from loguru import logger -from ...common.database import Database +from ...common.database import db from ..models.utils_model import LLM_request from .config import global_config from .message import MessageRecv, MessageThinking, Message @@ -34,7 +34,6 @@ class ResponseGenerator: self.model_v25 = LLM_request( model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000 ) - self.db = Database.get_instance() self.current_model_type = "r1" # 默认使用 R1 async def generate_response( @@ -154,7 +153,7 @@ class ResponseGenerator: reasoning_content: str, ): """保存对话记录到数据库""" - self.db.reasoning_logs.insert_one( + db.reasoning_logs.insert_one( { "time": time.time(), "chat_id": message.chat_stream.stream_id, @@ -211,7 +210,6 @@ class ResponseGenerator: class InitiativeMessageGenerate: def __init__(self): - self.db = Database.get_instance() self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7) self.model_r1_distill = LLM_request( diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index c89bf3e07..a41ed51e2 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -3,7 +3,7 @@ import time from typing import Optional from loguru import logger -from ...common.database import Database +from ...common.database import db from ..memory_system.memory import hippocampus, memory_graph from ..moods.moods import MoodManager from ..schedule.schedule_generator import bot_schedule @@ -16,7 +16,6 @@ class PromptBuilder: def __init__(self): self.prompt_built = '' self.activate_messages = '' - self.db = Database.get_instance() @@ -76,7 +75,7 @@ class PromptBuilder: chat_in_group=True chat_talking_prompt = '' if stream_id: - chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) + chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_stream=chat_manager.get_stream(stream_id) if chat_stream.group_info: chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" @@ -199,7 +198,7 @@ class PromptBuilder: chat_talking_prompt = '' if group_id: - chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, + chat_talking_prompt = get_recent_group_detailed_plain_text(group_id, limit=global_config.MAX_CONTEXT_SIZE, combine=True) @@ -311,7 +310,7 @@ class PromptBuilder: {"$project": {"content": 1, "similarity": 1}} ] - results = list(self.db.knowledges.aggregate(pipeline)) + results = list(db.knowledges.aggregate(pipeline)) # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") if not results: diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index fbd8cec59..d604e6734 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -2,7 +2,7 @@ import asyncio from typing import Optional from loguru import logger -from ...common.database import Database +from ...common.database import db from .message_base import UserInfo from .chat_stream import ChatStream @@ -167,14 +167,12 @@ class RelationshipManager: async def load_all_relationships(self): """加载所有关系对象""" - db = Database.get_instance() all_relationships = db.relationships.find({}) for data in all_relationships: await self.load_relationship(data) async def _start_relationship_manager(self): """每5分钟自动保存一次关系数据""" - db = Database.get_instance() # 获取所有关系记录 all_relationships = db.relationships.find({}) # 依次加载每条记录 @@ -205,7 +203,6 @@ class RelationshipManager: age = relationship.age saved = relationship.saved - db = Database.get_instance() db.relationships.update_one( {'user_id': user_id, 'platform': platform}, {'$set': { diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index ec155bbe9..ad6662f2b 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -1,15 +1,12 @@ from typing import Optional, Union -from ...common.database import Database +from ...common.database import db from .message import MessageSending, MessageRecv from .chat_stream import ChatStream from loguru import logger class MessageStorage: - def __init__(self): - self.db = Database.get_instance() - async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None: """存储消息到数据库""" try: @@ -23,7 +20,7 @@ class MessageStorage: "detailed_plain_text": message.detailed_plain_text, "topic": topic, } - self.db.messages.insert_one(message_data) + db.messages.insert_one(message_data) except Exception: logger.exception("存储消息失败") diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 0d1afd055..f28d0e192 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -16,6 +16,7 @@ from .message import MessageRecv,Message from .message_base import UserInfo from .chat_stream import ChatStream from ..moods.moods import MoodManager +from ...common.database import db driver = get_driver() config = driver.config @@ -76,11 +77,10 @@ def calculate_information_content(text): return entropy -def get_cloest_chat_from_db(db, length: int, timestamp: str): +def get_closest_chat_from_db(length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录 Args: - db: 数据库实例 length: 要获取的消息数量 timestamp: 时间戳 @@ -115,11 +115,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): return [] -async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list: +async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 Args: - db: Database实例 group_id: 群组ID limit: 获取消息数量,默认12条 @@ -161,7 +160,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list: return message_objects -def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False): +def get_recent_group_detailed_plain_text(chat_stream_id: int, limit: int = 12, combine=False): recent_messages = list(db.messages.find( {"chat_id": chat_stream_id}, { diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 94014b5b4..2154280de 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -10,7 +10,7 @@ import io from loguru import logger from nonebot import get_driver -from ...common.database import Database +from ...common.database import db from ..chat.config import global_config from ..models.utils_model import LLM_request driver = get_driver() @@ -23,13 +23,11 @@ class ImageManager: def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance.db = None cls._instance._initialized = False return cls._instance def __init__(self): if not self._initialized: - self.db = Database.get_instance() self._ensure_image_collection() self._ensure_description_collection() self._ensure_image_dir() @@ -42,20 +40,20 @@ class ImageManager: def _ensure_image_collection(self): """确保images集合存在并创建索引""" - if 'images' not in self.db.list_collection_names(): - self.db.create_collection('images') + if 'images' not in db.list_collection_names(): + db.create_collection('images') # 创建索引 - self.db.images.create_index([('hash', 1)], unique=True) - self.db.images.create_index([('url', 1)]) - self.db.images.create_index([('path', 1)]) + db.images.create_index([('hash', 1)], unique=True) + db.images.create_index([('url', 1)]) + db.images.create_index([('path', 1)]) def _ensure_description_collection(self): """确保image_descriptions集合存在并创建索引""" - if 'image_descriptions' not in self.db.list_collection_names(): - self.db.create_collection('image_descriptions') + if 'image_descriptions' not in db.list_collection_names(): + db.create_collection('image_descriptions') # 创建索引 - self.db.image_descriptions.create_index([('hash', 1)], unique=True) - self.db.image_descriptions.create_index([('type', 1)]) + db.image_descriptions.create_index([('hash', 1)], unique=True) + db.image_descriptions.create_index([('type', 1)]) def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -67,7 +65,7 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result= self.db.image_descriptions.find_one({ + result= db.image_descriptions.find_one({ 'hash': image_hash, 'type': description_type }) @@ -81,7 +79,7 @@ class ImageManager: description: 描述文本 description_type: 描述类型 ('emoji' 或 'image') """ - self.db.image_descriptions.update_one( + db.image_descriptions.update_one( {'hash': image_hash, 'type': description_type}, { '$set': { @@ -124,7 +122,7 @@ class ImageManager: image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # 查重 - existing = self.db.images.find_one({'hash': image_hash}) + existing = db.images.find_one({'hash': image_hash}) if existing: return existing['path'] @@ -145,7 +143,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.insert_one(image_doc) + db.images.insert_one(image_doc) return file_path @@ -162,7 +160,7 @@ class ImageManager: """ try: # 先查找是否已存在 - existing = self.db.images.find_one({'url': url}) + existing = db.images.find_one({'url': url}) if existing: return existing['path'] @@ -206,7 +204,7 @@ class ImageManager: Returns: bool: 是否存在 """ - return self.db.images.find_one({'url': url}) is not None + return db.images.find_one({'url': url}) is not None def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: """检查图像是否已存在 @@ -229,7 +227,7 @@ class ImageManager: return False image_hash = hashlib.md5(image_bytes).hexdigest() - return self.db.images.find_one({'hash': image_hash}) is not None + return db.images.find_one({'hash': image_hash}) is not None except Exception as e: logger.error(f"检查哈希失败: {str(e)}") @@ -273,7 +271,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.update_one( + db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -335,7 +333,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.update_one( + db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index d6ba8f3b2..df699f459 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -13,7 +13,7 @@ from loguru import logger root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) -from src.common.database import Database # 使用正确的导入语法 +from src.common.database import db # 使用正确的导入语法 # 加载.env.dev文件 env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev') @@ -23,7 +23,6 @@ load_dotenv(env_path) class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 - self.db = Database.get_instance() def connect_dot(self, concept1, concept2): self.G.add_edge(concept1, concept2) @@ -96,7 +95,7 @@ class Memory_graph: dot_data = { "concept": node } - self.db.store_memory_dots.insert_one(dot_data) + db.store_memory_dots.insert_one(dot_data) @property def dots(self): @@ -106,7 +105,7 @@ class Memory_graph: def get_random_chat_from_db(self, length: int, timestamp: str): # 从数据库中根据时间戳获取离其最近的聊天记录 chat_text = '' - closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 logger.info( f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") @@ -115,7 +114,7 @@ class Memory_graph: group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_record = list( - self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( + db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( length)) for record in chat_record: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) @@ -130,50 +129,39 @@ class Memory_graph: def save_graph_to_db(self): # 清空现有的图数据 - self.db.graph_data.delete_many({}) + db.graph_data.delete_many({}) # 保存节点 for node in self.G.nodes(data=True): node_data = { 'concept': node[0], 'memory_items': node[1].get('memory_items', []) # 默认为空列表 } - self.db.graph_data.nodes.insert_one(node_data) + db.graph_data.nodes.insert_one(node_data) # 保存边 for edge in self.G.edges(): edge_data = { 'source': edge[0], 'target': edge[1] } - self.db.graph_data.edges.insert_one(edge_data) + db.graph_data.edges.insert_one(edge_data) def load_graph_from_db(self): # 清空当前图 self.G.clear() # 加载节点 - nodes = self.db.graph_data.nodes.find() + nodes = db.graph_data.nodes.find() for node in nodes: memory_items = node.get('memory_items', []) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] self.G.add_node(node['concept'], memory_items=memory_items) # 加载边 - edges = self.db.graph_data.edges.find() + edges = db.graph_data.edges.find() for edge in edges: self.G.add_edge(edge['source'], edge['target']) def main(): - # 初始化数据库 - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - memory_graph = Memory_graph() memory_graph.load_graph_from_db() diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index d9e867e63..f87f037d5 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -10,12 +10,12 @@ import networkx as nx from loguru import logger from nonebot import get_driver -from ...common.database import Database # 使用正确的导入语法 +from ...common.database import db # 使用正确的导入语法 from ..chat.config import global_config from ..chat.utils import ( calculate_information_content, cosine_similarity, - get_cloest_chat_from_db, + get_closest_chat_from_db, text_to_vector, ) from ..models.utils_model import LLM_request @@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 - self.db = Database.get_instance() def connect_dot(self, concept1, concept2): # 避免自连接 @@ -191,19 +190,19 @@ class Hippocampus: # 短期:1h 中期:4h 长期:24h for _ in range(time_frequency.get('near')): random_time = current_timestamp - random.randint(1, 3600) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('mid')): random_time = current_timestamp - random.randint(3600, 3600 * 4) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('far')): random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) @@ -349,7 +348,7 @@ class Hippocampus: def sync_memory_to_db(self): """检查并同步内存中的图结构与数据库""" # 获取数据库中所有节点和内存中所有节点 - db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) + db_nodes = list(db.graph_data.nodes.find()) memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 转换数据库节点为字典格式,方便查找 @@ -377,7 +376,7 @@ class Hippocampus: 'created_time': created_time, 'last_modified': last_modified } - self.memory_graph.db.graph_data.nodes.insert_one(node_data) + db.graph_data.nodes.insert_one(node_data) else: # 获取数据库中节点的特征值 db_node = db_nodes_dict[concept] @@ -385,7 +384,7 @@ class Hippocampus: # 如果特征值不同,则更新节点 if db_hash != memory_hash: - self.memory_graph.db.graph_data.nodes.update_one( + db.graph_data.nodes.update_one( {'concept': concept}, {'$set': { 'memory_items': memory_items, @@ -396,7 +395,7 @@ class Hippocampus: ) # 处理边的信息 - db_edges = list(self.memory_graph.db.graph_data.edges.find()) + db_edges = list(db.graph_data.edges.find()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 @@ -428,11 +427,11 @@ class Hippocampus: 'created_time': created_time, 'last_modified': last_modified } - self.memory_graph.db.graph_data.edges.insert_one(edge_data) + db.graph_data.edges.insert_one(edge_data) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]['hash'] != edge_hash: - self.memory_graph.db.graph_data.edges.update_one( + db.graph_data.edges.update_one( {'source': source, 'target': target}, {'$set': { 'hash': edge_hash, @@ -451,7 +450,7 @@ class Hippocampus: self.memory_graph.G.clear() # 从数据库加载所有节点 - nodes = list(self.memory_graph.db.graph_data.nodes.find()) + nodes = list(db.graph_data.nodes.find()) for node in nodes: concept = node['concept'] memory_items = node.get('memory_items', []) @@ -468,7 +467,7 @@ class Hippocampus: if 'last_modified' not in node: update_data['last_modified'] = current_time - self.memory_graph.db.graph_data.nodes.update_one( + db.graph_data.nodes.update_one( {'concept': concept}, {'$set': update_data} ) @@ -485,7 +484,7 @@ class Hippocampus: last_modified=last_modified) # 从数据库加载所有边 - edges = list(self.memory_graph.db.graph_data.edges.find()) + edges = list(db.graph_data.edges.find()) for edge in edges: source = edge['source'] target = edge['target'] @@ -501,7 +500,7 @@ class Hippocampus: if 'last_modified' not in edge: update_data['last_modified'] = current_time - self.memory_graph.db.graph_data.edges.update_one( + db.graph_data.edges.update_one( {'source': source, 'target': target}, {'$set': update_data} ) diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py index adf972a06..2d16998e0 100644 --- a/src/plugins/memory_system/memory_manual_build.py +++ b/src/plugins/memory_system/memory_manual_build.py @@ -19,7 +19,7 @@ import jieba root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) -from src.common.database import Database +from src.common.database import db from src.plugins.memory_system.offline_llm import LLMModel # 获取当前文件的目录 @@ -49,7 +49,7 @@ def calculate_information_content(text): return entropy -def get_cloest_chat_from_db(db, length: int, timestamp: str): +def get_closest_chat_from_db(length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数 Returns: @@ -91,7 +91,6 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 - self.db = Database.get_instance() def connect_dot(self, concept1, concept2): # 如果边已存在,增加 strength @@ -186,19 +185,19 @@ class Hippocampus: # 短期:1h 中期:4h 长期:24h for _ in range(time_frequency.get('near')): random_time = current_timestamp - random.randint(1, 3600*4) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('mid')): random_time = current_timestamp - random.randint(3600*4, 3600*24) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('far')): random_time = current_timestamp - random.randint(3600*24, 3600*24*7) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) @@ -323,7 +322,7 @@ class Hippocampus: self.memory_graph.G.clear() # 从数据库加载所有节点 - nodes = self.memory_graph.db.graph_data.nodes.find() + nodes = db.graph_data.nodes.find() for node in nodes: concept = node['concept'] memory_items = node.get('memory_items', []) @@ -334,7 +333,7 @@ class Hippocampus: self.memory_graph.G.add_node(concept, memory_items=memory_items) # 从数据库加载所有边 - edges = self.memory_graph.db.graph_data.edges.find() + edges = db.graph_data.edges.find() for edge in edges: source = edge['source'] target = edge['target'] @@ -371,7 +370,7 @@ class Hippocampus: 使用特征值(哈希值)快速判断是否需要更新 """ # 获取数据库中所有节点和内存中所有节点 - db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) + db_nodes = list(db.graph_data.nodes.find()) memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 转换数据库节点为字典格式,方便查找 @@ -394,7 +393,7 @@ class Hippocampus: 'memory_items': memory_items, 'hash': memory_hash } - self.memory_graph.db.graph_data.nodes.insert_one(node_data) + db.graph_data.nodes.insert_one(node_data) else: # 获取数据库中节点的特征值 db_node = db_nodes_dict[concept] @@ -403,7 +402,7 @@ class Hippocampus: # 如果特征值不同,则更新节点 if db_hash != memory_hash: # logger.info(f"更新节点内容: {concept}") - self.memory_graph.db.graph_data.nodes.update_one( + db.graph_data.nodes.update_one( {'concept': concept}, {'$set': { 'memory_items': memory_items, @@ -416,10 +415,10 @@ class Hippocampus: for db_node in db_nodes: if db_node['concept'] not in memory_concepts: # logger.info(f"删除多余节点: {db_node['concept']}") - self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) + db.graph_data.nodes.delete_one({'concept': db_node['concept']}) # 处理边的信息 - db_edges = list(self.memory_graph.db.graph_data.edges.find()) + db_edges = list(db.graph_data.edges.find()) memory_edges = list(self.memory_graph.G.edges()) # 创建边的哈希值字典 @@ -445,12 +444,12 @@ class Hippocampus: 'num': 1, 'hash': edge_hash } - self.memory_graph.db.graph_data.edges.insert_one(edge_data) + db.graph_data.edges.insert_one(edge_data) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]['hash'] != edge_hash: logger.info(f"更新边: {source} - {target}") - self.memory_graph.db.graph_data.edges.update_one( + db.graph_data.edges.update_one( {'source': source, 'target': target}, {'$set': {'hash': edge_hash}} ) @@ -461,7 +460,7 @@ class Hippocampus: if edge_key not in memory_edge_set: source, target = edge_key logger.info(f"删除多余边: {source} - {target}") - self.memory_graph.db.graph_data.edges.delete_one({ + db.graph_data.edges.delete_one({ 'source': source, 'target': target }) @@ -487,9 +486,9 @@ class Hippocampus: topic: 要删除的节点概念 """ # 删除节点 - self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic}) + db.graph_data.nodes.delete_one({'concept': topic}) # 删除所有涉及该节点的边 - self.memory_graph.db.graph_data.edges.delete_many({ + db.graph_data.edges.delete_many({ '$or': [ {'source': topic}, {'target': topic} @@ -902,17 +901,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal plt.show() async def main(): - # 初始化数据库 - logger.info("正在初始化数据库连接...") - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) start_time = time.time() test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} diff --git a/src/plugins/memory_system/memory_test1.py b/src/plugins/memory_system/memory_test1.py index f86c8ea3d..245eb9b26 100644 --- a/src/plugins/memory_system/memory_test1.py +++ b/src/plugins/memory_system/memory_test1.py @@ -38,7 +38,7 @@ import jieba # from chat.config import global_config sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 -from src.common.database import Database +from src.common.database import db from src.plugins.memory_system.offline_llm import LLMModel # 获取当前文件的目录 @@ -56,45 +56,6 @@ else: logger.warning(f"未找到环境变量文件: {env_path}") logger.info("将使用默认配置") -class Database: - _instance = None - db = None - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def __init__(self): - if not Database.db: - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - - @classmethod - def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"): - try: - if username and password: - uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}" - else: - uri = f"mongodb://{host}:{port}" - - client = pymongo.MongoClient(uri) - cls.db = client[db_name] - # 测试连接 - client.server_info() - logger.success("MongoDB连接成功!") - - except Exception as e: - logger.error(f"初始化MongoDB失败: {str(e)}") - raise def calculate_information_content(text): """计算文本的信息量(熵)""" @@ -108,7 +69,7 @@ def calculate_information_content(text): return entropy -def get_cloest_chat_from_db(db, length: int, timestamp: str): +def get_closest_chat_from_db(length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数 Returns: @@ -163,7 +124,7 @@ class Memory_cortex: default_time = datetime.datetime.now().timestamp() # 从数据库加载所有节点 - nodes = self.memory_graph.db.graph_data.nodes.find() + nodes = db.graph_data.nodes.find() for node in nodes: concept = node['concept'] memory_items = node.get('memory_items', []) @@ -180,7 +141,7 @@ class Memory_cortex: created_time = default_time last_modified = default_time # 更新数据库中的节点 - self.memory_graph.db.graph_data.nodes.update_one( + db.graph_data.nodes.update_one( {'concept': concept}, {'$set': { 'created_time': created_time, @@ -196,7 +157,7 @@ class Memory_cortex: last_modified=last_modified) # 从数据库加载所有边 - edges = self.memory_graph.db.graph_data.edges.find() + edges = db.graph_data.edges.find() for edge in edges: source = edge['source'] target = edge['target'] @@ -212,7 +173,7 @@ class Memory_cortex: created_time = default_time last_modified = default_time # 更新数据库中的边 - self.memory_graph.db.graph_data.edges.update_one( + db.graph_data.edges.update_one( {'source': source, 'target': target}, {'$set': { 'created_time': created_time, @@ -256,7 +217,7 @@ class Memory_cortex: current_time = datetime.datetime.now().timestamp() # 获取数据库中所有节点和内存中所有节点 - db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) + db_nodes = list(db.graph_data.nodes.find()) memory_nodes = list(self.memory_graph.G.nodes(data=True)) # 转换数据库节点为字典格式,方便查找 @@ -280,7 +241,7 @@ class Memory_cortex: 'created_time': data.get('created_time', current_time), 'last_modified': data.get('last_modified', current_time) } - self.memory_graph.db.graph_data.nodes.insert_one(node_data) + db.graph_data.nodes.insert_one(node_data) else: # 获取数据库中节点的特征值 db_node = db_nodes_dict[concept] @@ -288,7 +249,7 @@ class Memory_cortex: # 如果特征值不同,则更新节点 if db_hash != memory_hash: - self.memory_graph.db.graph_data.nodes.update_one( + db.graph_data.nodes.update_one( {'concept': concept}, {'$set': { 'memory_items': memory_items, @@ -301,10 +262,10 @@ class Memory_cortex: memory_concepts = set(node[0] for node in memory_nodes) for db_node in db_nodes: if db_node['concept'] not in memory_concepts: - self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) + db.graph_data.nodes.delete_one({'concept': db_node['concept']}) # 处理边的信息 - db_edges = list(self.memory_graph.db.graph_data.edges.find()) + db_edges = list(db.graph_data.edges.find()) memory_edges = list(self.memory_graph.G.edges(data=True)) # 创建边的哈希值字典 @@ -332,11 +293,11 @@ class Memory_cortex: 'created_time': data.get('created_time', current_time), 'last_modified': data.get('last_modified', current_time) } - self.memory_graph.db.graph_data.edges.insert_one(edge_data) + db.graph_data.edges.insert_one(edge_data) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]['hash'] != edge_hash: - self.memory_graph.db.graph_data.edges.update_one( + db.graph_data.edges.update_one( {'source': source, 'target': target}, {'$set': { 'hash': edge_hash, @@ -350,7 +311,7 @@ class Memory_cortex: for edge_key in db_edge_dict: if edge_key not in memory_edge_set: source, target = edge_key - self.memory_graph.db.graph_data.edges.delete_one({ + db.graph_data.edges.delete_one({ 'source': source, 'target': target }) @@ -365,9 +326,9 @@ class Memory_cortex: topic: 要删除的节点概念 """ # 删除节点 - self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic}) + db.graph_data.nodes.delete_one({'concept': topic}) # 删除所有涉及该节点的边 - self.memory_graph.db.graph_data.edges.delete_many({ + db.graph_data.edges.delete_many({ '$or': [ {'source': topic}, {'target': topic} @@ -377,7 +338,6 @@ class Memory_cortex: class Memory_graph: def __init__(self): self.G = nx.Graph() # 使用 networkx 的图结构 - self.db = Database.get_instance() def connect_dot(self, concept1, concept2): # 避免自连接 @@ -492,19 +452,19 @@ class Hippocampus: # 短期:1h 中期:4h 长期:24h for _ in range(time_frequency.get('near')): random_time = current_timestamp - random.randint(1, 3600*4) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('mid')): random_time = current_timestamp - random.randint(3600*4, 3600*24) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) for _ in range(time_frequency.get('far')): random_time = current_timestamp - random.randint(3600*24, 3600*24*7) - messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) + messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time) if messages: chat_samples.append(messages) @@ -1134,7 +1094,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal async def main(): # 初始化数据库 logger.info("正在初始化数据库连接...") - db = Database.get_instance() start_time = time.time() test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index aa07bb55d..afe4baeb5 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -10,7 +10,7 @@ from nonebot import get_driver import base64 from PIL import Image import io -from ...common.database import Database +from ...common.database import db from ..chat.config import global_config driver = get_driver() @@ -34,17 +34,16 @@ class LLM_request: self.pri_out = model.get("pri_out", 0) # 获取数据库实例 - self.db = Database.get_instance() self._init_database() def _init_database(self): """初始化数据库集合""" try: # 创建llm_usage集合的索引 - self.db.llm_usage.create_index([("timestamp", 1)]) - self.db.llm_usage.create_index([("model_name", 1)]) - self.db.llm_usage.create_index([("user_id", 1)]) - self.db.llm_usage.create_index([("request_type", 1)]) + db.llm_usage.create_index([("timestamp", 1)]) + db.llm_usage.create_index([("model_name", 1)]) + db.llm_usage.create_index([("user_id", 1)]) + db.llm_usage.create_index([("request_type", 1)]) except Exception: logger.error("创建数据库索引失败") @@ -73,7 +72,7 @@ class LLM_request: "status": "success", "timestamp": datetime.now() } - self.db.llm_usage.insert_one(usage_data) + db.llm_usage.insert_one(usage_data) logger.info( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index bde593890..5f62d6aca 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -8,7 +8,7 @@ from nonebot import get_driver from src.plugins.chat.config import global_config -from ...common.database import Database # 使用正确的导入语法 +from ...common.database import db # 使用正确的导入语法 from ..models.utils_model import LLM_request driver = get_driver() @@ -19,7 +19,6 @@ class ScheduleGenerator: # 根据global_config.llm_normal这一字典配置指定模型 # self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9) self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9) - self.db = Database.get_instance() self.today_schedule_text = "" self.today_schedule = {} self.tomorrow_schedule_text = "" @@ -46,7 +45,7 @@ class ScheduleGenerator: schedule_text = str - existing_schedule = self.db.schedule.find_one({"date": date_str}) + existing_schedule = db.schedule.find_one({"date": date_str}) if existing_schedule: logger.debug(f"{date_str}的日程已存在:") schedule_text = existing_schedule["schedule"] @@ -63,7 +62,7 @@ class ScheduleGenerator: try: schedule_text, _ = await self.llm_scheduler.generate_response(prompt) - self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) + db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) except Exception as e: logger.error(f"生成日程失败: {str(e)}") schedule_text = "生成日程时出错了" @@ -143,7 +142,7 @@ class ScheduleGenerator: """打印完整的日程安排""" if not self._parse_schedule(self.today_schedule_text): logger.warning("今日日程有误,将在下次运行时重新生成") - self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) + db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) else: logger.info("=== 今日日程安排 ===") for time_str, activity in self.today_schedule.items(): diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py index 4629f0e0b..e812bce4b 100644 --- a/src/plugins/utils/statistic.py +++ b/src/plugins/utils/statistic.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta from typing import Any, Dict from loguru import logger -from ...common.database import Database +from ...common.database import db class LLMStatistics: @@ -15,7 +15,6 @@ class LLMStatistics: Args: output_file: 统计结果输出文件路径 """ - self.db = Database.get_instance() self.output_file = output_file self.running = False self.stats_thread = None @@ -53,7 +52,7 @@ class LLMStatistics: "costs_by_model": defaultdict(float) } - cursor = self.db.llm_usage.find({ + cursor = db.llm_usage.find({ "timestamp": {"$gte": start_time} }) diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py index ad309814b..a049394fe 100644 --- a/src/plugins/zhishi/knowledge_library.py +++ b/src/plugins/zhishi/knowledge_library.py @@ -14,7 +14,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) sys.path.append(root_path) # 现在可以导入src模块 -from src.common.database import Database +from src.common.database import db # 加载根目录下的env.edv文件 env_path = os.path.join(root_path, ".env.prod") @@ -24,18 +24,6 @@ load_dotenv(env_path) class KnowledgeLibrary: def __init__(self): - # 初始化数据库连接 - if Database._instance is None: - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - self.db = Database.get_instance() self.raw_info_dir = "data/raw_info" self._ensure_dirs() self.api_key = os.getenv("SILICONFLOW_KEY") @@ -176,7 +164,7 @@ class KnowledgeLibrary: try: current_hash = self.calculate_file_hash(file_path) - processed_record = self.db.processed_files.find_one({"file_path": file_path}) + processed_record = db.processed_files.find_one({"file_path": file_path}) if processed_record: if processed_record.get("hash") == current_hash: @@ -197,14 +185,14 @@ class KnowledgeLibrary: "split_length": knowledge_length, "created_at": datetime.now() } - self.db.knowledges.insert_one(knowledge) + db.knowledges.insert_one(knowledge) result["chunks_processed"] += 1 split_by = processed_record.get("split_by", []) if processed_record else [] if knowledge_length not in split_by: split_by.append(knowledge_length) - self.db.knowledges.processed_files.update_one( + db.knowledges.processed_files.update_one( {"file_path": file_path}, { "$set": { @@ -322,7 +310,7 @@ class KnowledgeLibrary: {"$project": {"content": 1, "similarity": 1, "file_path": 1}} ] - results = list(self.db.knowledges.aggregate(pipeline)) + results = list(db.knowledges.aggregate(pipeline)) return results # 创建单例实例 @@ -346,7 +334,7 @@ if __name__ == "__main__": elif choice == '2': confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower() if confirm == 'y': - knowledge_library.db.knowledges.delete_many({}) + db.knowledges.delete_many({}) console.print("[green]已清空所有知识![/green]") continue elif choice == '1':