feat: 实现MongoDB URI方式连接,并统一数据库连接代码。

This commit is contained in:
春河晴
2025-03-10 14:48:43 +09:00
parent c9f12446c0
commit 4baa6c6f0a
9 changed files with 82 additions and 121 deletions

View File

@@ -31,6 +31,7 @@ driver = get_driver()
config = driver.config
Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,

View File

@@ -37,14 +37,7 @@ def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
os.makedirs(images_dir, exist_ok=True)
# 连接数据库
db = Database(
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
db = Database.get_instance()
# 检查是否已存在相同哈希值的图片
collection = db.db['images']

View File

@@ -19,12 +19,13 @@ from src.common.database import Database
# 从环境变量获取配置
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "maimai"),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class KnowledgeLibrary:

View File

@@ -162,12 +162,13 @@ class Memory_graph:
def main():
# 初始化数据库
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME", ""),
password=os.getenv("MONGODB_PASSWORD", ""),
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "")
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph()

View File

@@ -8,6 +8,7 @@ import jieba
import networkx as nx
from loguru import logger
from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法
from ..chat.config import global_config
from ..chat.utils import (
@@ -18,7 +19,6 @@ from ..chat.utils import (
)
from ..models.utils_model import LLM_request
class Memory_graph:
def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构
@@ -130,7 +130,7 @@ class Memory_graph:
return None
# 海马体
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
@@ -749,15 +749,13 @@ def segment_text(text):
seg_text = list(jieba.cut(text))
return seg_text
from nonebot import get_driver
driver = get_driver()
config = driver.config
start_time = time.time()
Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST,
port=config.MONGODB_PORT,
db_name=config.DATABASE_NAME,

View File

@@ -35,45 +35,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
host=os.getenv("MONGODB_HOST"),
port=int(os.getenv("MONGODB_PORT")),
db_name=os.getenv("DATABASE_NAME"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE")
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text):
"""计算文本的信息量(熵)"""
char_count = Counter(text)
@@ -202,7 +163,7 @@ class Memory_graph:
# 返回所有节点对应的 Memory_dot 对象
return [self.get_dot(node) for node in self.G.nodes()]
# 海马体
# 海马体
class Hippocampus:
def __init__(self, memory_graph: Memory_graph):
self.memory_graph = memory_graph
@@ -941,59 +902,67 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
db = Database.get_instance()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}
# 创建记忆图
memory_graph = Memory_graph()
# 创建海马体
hippocampus = Hippocampus(memory_graph)
# 从数据库同步数据
hippocampus.sync_memory_from_db()
end_time = time.time()
logger.info(f"\033[32m[加载海马体耗时: {end_time - start_time:.2f} 秒]\033[0m")
# 构建记忆
if test_pare['do_build_memory']:
logger.info("开始构建记忆...")
chat_size = 20
await hippocampus.operation_build_memory(chat_size=chat_size)
end_time = time.time()
logger.info(f"\033[32m[构建记忆耗时: {end_time - start_time:.2f} 秒,chat_size={chat_size},chat_count = 16]\033[0m")
if test_pare['do_forget_topic']:
logger.info("开始遗忘记忆...")
await hippocampus.operation_forget_topic(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[遗忘记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_merge_memory']:
logger.info("开始合并记忆...")
await hippocampus.operation_merge_memory(percentage=0.1)
end_time = time.time()
logger.info(f"\033[32m[合并记忆耗时: {end_time - start_time:.2f} 秒]\033[0m")
if test_pare['do_visualize_graph']:
# 展示优化后的图形
logger.info("生成记忆图谱可视化...")
print("\n生成优化后的记忆图谱:")
visualize_graph_lite(memory_graph)
if test_pare['do_query']:
# 交互式查询
while True:
query = input("\n请输入新的查询概念(输入'退出'以结束):")
if query.lower() == '退出':
break
items_list = memory_graph.get_related_item(query)
if items_list:
first_layer, second_layer = items_list
@@ -1008,9 +977,6 @@ async def main():
else:
print("未找到相关记忆。")
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -14,6 +14,7 @@ driver = get_driver()
config = driver.config
Database.initialize(
uri=config.MONGODB_URI,
host=config.MONGODB_HOST,
port=int(config.MONGODB_PORT),
db_name=config.DATABASE_NAME,
@@ -22,7 +23,6 @@ Database.initialize(
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator:
def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型