Merge pull request #287 from BBleae/debug

重构database.py为全局对象模式
This commit is contained in:
tcmofashi
2025-03-12 21:43:57 +08:00
committed by GitHub
20 changed files with 188 additions and 354 deletions

15
bot.py
View File

@@ -12,8 +12,6 @@ from loguru import logger
from nonebot.adapters.onebot.v11 import Adapter from nonebot.adapters.onebot.v11 import Adapter
import platform import platform
from src.common.database import Database
# 获取没有加载env时的环境变量 # 获取没有加载env时的环境变量
env_mask = {key: os.getenv(key) for key in os.environ} env_mask = {key: os.getenv(key) for key in os.environ}
@@ -111,18 +109,6 @@ def load_env():
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def init_database():
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
def load_logger(): def load_logger():
logger.remove() # 移除默认配置 logger.remove() # 移除默认配置
if os.getenv("ENVIRONMENT") == "dev": if os.getenv("ENVIRONMENT") == "dev":
@@ -223,7 +209,6 @@ def raw_main():
init_config() init_config()
init_env() init_env()
load_env() load_env()
init_database() # 加载完成环境后初始化database
load_logger() load_logger()
env_config = {key: os.getenv(key) for key in os.environ} env_config = {key: os.getenv(key) for key in os.environ}

View File

@@ -1,73 +1,53 @@
from typing import Optional import os
from typing import cast
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.database import Database as MongoDatabase from pymongo.database import Database
class Database: _client = None
_instance: Optional["Database"] = None _db = None
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
db_name = os.getenv("DATABASE_NAME", "MegBot")
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
def __init__(
self,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
):
if uri and uri.startswith("mongodb://"): if uri and uri.startswith("mongodb://"):
# 优先使用URI连接 # 优先使用URI连接
self.client = MongoClient(uri) return MongoClient(uri)
elif username and password:
if username and password:
# 如果有用户名和密码,使用认证连接 # 如果有用户名和密码,使用认证连接
self.client = MongoClient( return MongoClient(
host, port, username=username, password=password, authSource=auth_source host, port, username=username, password=password, authSource=auth_source
) )
else:
# 否则使用无认证连接 # 否则使用无认证连接
self.client = MongoClient(host, port) return MongoClient(host, port)
self.db: MongoDatabase = self.client[db_name]
@classmethod
def initialize(
cls,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> MongoDatabase:
if cls._instance is None:
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance.db
@classmethod
def get_instance(cls) -> MongoDatabase:
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance.db
#测试用 def get_db():
"""获取数据库连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
return _db
def get_random_group_messages(self, group_id: str, limit: int = 5):
# 先随机获取一条消息
random_message = list(self.db.messages.aggregate([
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# 获取该消息之后的消息 class DBWrapper:
subsequent_messages = list(self.db.messages.find({ """数据库代理类,保持接口兼容性同时实现懒加载。"""
"group_id": group_id,
"time": {"$gt": random_message["time"]}
}).sort("time", 1).limit(limit))
# 将随机消息和后续消息合并 def __getattr__(self, name):
messages = [random_message] + subsequent_messages return getattr(get_db(), name)
return messages def __getitem__(self, key):
return get_db()[key]
# 全局数据库访问点
db: Database = DBWrapper()

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from ..common.database import Database from ..common.database import db
import customtkinter as ctk import customtkinter as ctk
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -44,28 +44,6 @@ class ReasoningGUI:
self.root.geometry('800x600') self.root.geometry('800x600')
self.root.protocol("WM_DELETE_WINDOW", self._on_closing) self.root.protocol("WM_DELETE_WINDOW", self._on_closing)
# 初始化数据库连接
try:
self.db = Database.get_instance()
logger.success("数据库连接成功")
except RuntimeError:
logger.warning("数据库未初始化,正在尝试初始化...")
try:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
logger.success("数据库初始化成功")
except Exception:
logger.exception("数据库初始化失败")
sys.exit(1)
# 存储群组数据 # 存储群组数据
self.group_data: Dict[str, List[dict]] = {} self.group_data: Dict[str, List[dict]] = {}
@@ -264,11 +242,11 @@ class ReasoningGUI:
logger.debug(f"查询条件: {query}") logger.debug(f"查询条件: {query}")
# 先获取一条记录检查时间格式 # 先获取一条记录检查时间格式
sample = self.db.reasoning_logs.find_one() sample = db.reasoning_logs.find_one()
if sample: if sample:
logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}") logger.debug(f"样本记录时间格式: {type(sample['time'])} 值: {sample['time']}")
cursor = self.db.reasoning_logs.find(query).sort("time", -1) cursor = db.reasoning_logs.find(query).sort("time", -1)
new_data = {} new_data = {}
total_count = 0 total_count = 0
@@ -333,17 +311,6 @@ class ReasoningGUI:
def main(): def main():
"""主函数"""
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
app = ReasoningGUI() app = ReasoningGUI()
app.run() app.run()

View File

@@ -7,7 +7,6 @@ from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State from nonebot.typing import T_State
from ...common.database import Database
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics from ..utils.statistic import LLMStatistics

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import db
from .message_base import GroupInfo, UserInfo from .message_base import GroupInfo, UserInfo
@@ -83,7 +83,6 @@ class ChatManager:
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection() self._ensure_collection()
self._initialized = True self._initialized = True
# 在事件循环中启动初始化 # 在事件循环中启动初始化
@@ -111,11 +110,11 @@ class ChatManager:
def _ensure_collection(self): def _ensure_collection(self):
"""确保数据库集合存在并创建索引""" """确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.list_collection_names(): if "chat_streams" not in db.list_collection_names():
self.db.create_collection("chat_streams") db.create_collection("chat_streams")
# 创建索引 # 创建索引
self.db.chat_streams.create_index([("stream_id", 1)], unique=True) db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index( db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)] [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
) )
@@ -168,7 +167,7 @@ class ChatManager:
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
data = self.db.chat_streams.find_one({"stream_id": stream_id}) data = db.chat_streams.find_one({"stream_id": stream_id})
if data: if data:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息 # 更新用户信息和群组信息
@@ -204,7 +203,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream): async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库""" """保存聊天流到数据库"""
if not stream.saved: if not stream.saved:
self.db.chat_streams.update_one( db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
) )
stream.saved = True stream.saved = True
@@ -216,7 +215,7 @@ class ChatManager:
async def load_all_streams(self): async def load_all_streams(self):
"""从数据库加载所有聊天流""" """从数据库加载所有聊天流"""
all_streams = self.db.chat_streams.find({}) all_streams = db.chat_streams.find({})
for data in all_streams: for data in all_streams:
stream = ChatStream.from_dict(data) stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream self.streams[stream.stream_id] = stream

View File

@@ -12,7 +12,7 @@ import io
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from ...common.database import Database from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
@@ -30,12 +30,10 @@ class EmojiManager:
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self):
self.db = Database.get_instance()
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60, self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,
@@ -50,7 +48,6 @@ class EmojiManager:
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: if not self._initialized:
try: try:
self.db = Database.get_instance()
self._ensure_emoji_collection() self._ensure_emoji_collection()
self._ensure_emoji_dir() self._ensure_emoji_dir()
self._initialized = True self._initialized = True
@@ -78,16 +75,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
""" """
if 'emoji' not in self.db.list_collection_names(): if 'emoji' not in db.list_collection_names():
self.db.create_collection('emoji') db.create_collection('emoji')
self.db.emoji.create_index([('embedding', '2dsphere')]) db.emoji.create_index([('embedding', '2dsphere')])
self.db.emoji.create_index([('filename', 1)], unique=True) db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
self._ensure_db() self._ensure_db()
self.db.emoji.update_one( db.emoji.update_one(
{'_id': emoji_id}, {'_id': emoji_id},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
@@ -121,7 +118,7 @@ class EmojiManager:
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) all_emojis = list(db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")
@@ -159,7 +156,7 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
self.db.emoji.update_one( db.emoji.update_one(
{'_id': selected_emoji['_id']}, {'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
@@ -241,14 +238,14 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = self.db['emoji'].find_one({'filename': filename}) existing_emoji = db['emoji'].find_one({'filename': filename})
description = None description = None
if existing_emoji: if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合 # 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription') description = existing_emoji.get('discription')
# 检查是否在images集合中存在 # 检查是否在images集合中存在
existing_image = image_manager.db.images.find_one({'hash': image_hash}) existing_image = db.images.find_one({'hash': image_hash})
if not existing_image: if not existing_image:
# 同步到images集合 # 同步到images集合
image_doc = { image_doc = {
@@ -258,7 +255,7 @@ class EmojiManager:
'description': description, 'description': description,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
image_manager.db.images.update_one( db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -307,7 +304,7 @@ class EmojiManager:
} }
# 保存到emoji数据库 # 保存到emoji数据库
self.db['emoji'].insert_one(emoji_record) db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
@@ -320,7 +317,7 @@ class EmojiManager:
'description': description, 'description': description,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
image_manager.db.images.update_one( db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -348,7 +345,7 @@ class EmojiManager:
try: try:
self._ensure_db() self._ensure_db()
# 获取所有表情包记录 # 获取所有表情包记录
all_emojis = list(self.db.emoji.find()) all_emojis = list(db.emoji.find())
removed_count = 0 removed_count = 0
total_count = len(all_emojis) total_count = len(all_emojis)
@@ -356,13 +353,13 @@ class EmojiManager:
try: try:
if 'path' not in emoji: if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({'_id': emoji['_id']}) db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
if 'embedding' not in emoji: if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({'_id': emoji['_id']}) db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
@@ -370,7 +367,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']): if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}") logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录 # 从数据库中删除记录
result = self.db.emoji.delete_one({'_id': emoji['_id']}) result = db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}") logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1 removed_count += 1
@@ -381,7 +378,7 @@ class EmojiManager:
continue continue
# 验证清理结果 # 验证清理结果
remaining_count = self.db.emoji.count_documents({}) remaining_count = db.emoji.count_documents({})
if removed_count > 0: if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录") logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from nonebot import get_driver from nonebot import get_driver
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import db
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .message import MessageRecv, MessageThinking, Message from .message import MessageRecv, MessageThinking, Message
@@ -34,7 +34,6 @@ class ResponseGenerator:
self.model_v25 = LLM_request( self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000 model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
) )
self.db = Database.get_instance()
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
async def generate_response( async def generate_response(
@@ -154,7 +153,7 @@ class ResponseGenerator:
reasoning_content: str, reasoning_content: str,
): ):
"""保存对话记录到数据库""" """保存对话记录到数据库"""
self.db.reasoning_logs.insert_one( db.reasoning_logs.insert_one(
{ {
"time": time.time(), "time": time.time(),
"chat_id": message.chat_stream.stream_id, "chat_id": message.chat_stream.stream_id,
@@ -211,7 +210,6 @@ class ResponseGenerator:
class InitiativeMessageGenerate: class InitiativeMessageGenerate:
def __init__(self): def __init__(self):
self.db = Database.get_instance()
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7) self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7) self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request( self.model_r1_distill = LLM_request(

View File

@@ -3,7 +3,7 @@ import time
from typing import Optional from typing import Optional
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import db
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.memory import hippocampus, memory_graph
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
@@ -16,7 +16,6 @@ class PromptBuilder:
def __init__(self): def __init__(self):
self.prompt_built = '' self.prompt_built = ''
self.activate_messages = '' self.activate_messages = ''
self.db = Database.get_instance()
@@ -76,7 +75,7 @@ class PromptBuilder:
chat_in_group=True chat_in_group=True
chat_talking_prompt = '' chat_talking_prompt = ''
if stream_id: if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id) chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info: if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
@@ -199,7 +198,7 @@ class PromptBuilder:
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, chat_talking_prompt = get_recent_group_detailed_plain_text(group_id,
limit=global_config.MAX_CONTEXT_SIZE, limit=global_config.MAX_CONTEXT_SIZE,
combine=True) combine=True)
@@ -311,7 +310,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}}
] ]
results = list(self.db.knowledges.aggregate(pipeline)) results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import Optional from typing import Optional
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import db
from .message_base import UserInfo from .message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
@@ -167,14 +167,12 @@ class RelationshipManager:
async def load_all_relationships(self): async def load_all_relationships(self):
"""加载所有关系对象""" """加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.relationships.find({}) all_relationships = db.relationships.find({})
for data in all_relationships: for data in all_relationships:
await self.load_relationship(data) await self.load_relationship(data)
async def _start_relationship_manager(self): async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据""" """每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录 # 获取所有关系记录
all_relationships = db.relationships.find({}) all_relationships = db.relationships.find({})
# 依次加载每条记录 # 依次加载每条记录
@@ -205,7 +203,6 @@ class RelationshipManager:
age = relationship.age age = relationship.age
saved = relationship.saved saved = relationship.saved
db = Database.get_instance()
db.relationships.update_one( db.relationships.update_one(
{'user_id': user_id, 'platform': platform}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {

View File

@@ -1,15 +1,12 @@
from typing import Optional, Union from typing import Optional, Union
from ...common.database import Database from ...common.database import db
from .message import MessageSending, MessageRecv from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream from .chat_stream import ChatStream
from loguru import logger from loguru import logger
class MessageStorage: class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
@@ -23,7 +20,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text, "detailed_plain_text": message.detailed_plain_text,
"topic": topic, "topic": topic,
} }
self.db.messages.insert_one(message_data) db.messages.insert_one(message_data)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -16,6 +16,7 @@ from .message import MessageRecv,Message
from .message_base import UserInfo from .message_base import UserInfo
from .chat_stream import ChatStream from .chat_stream import ChatStream
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
from ...common.database import db
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -76,11 +77,10 @@ def calculate_information_content(text):
return entropy return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str): def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录 """从数据库中获取最接近指定时间戳的聊天记录
Args: Args:
db: 数据库实例
length: 要获取的消息数量 length: 要获取的消息数量
timestamp: 时间戳 timestamp: 时间戳
@@ -115,11 +115,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return [] return []
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list: async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录
Args: Args:
db: Database实例
group_id: 群组ID group_id: 群组ID
limit: 获取消息数量默认12条 limit: 获取消息数量默认12条
@@ -161,7 +160,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
return message_objects return message_objects
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False): def get_recent_group_detailed_plain_text(chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.messages.find( recent_messages = list(db.messages.find(
{"chat_id": chat_stream_id}, {"chat_id": chat_stream_id},
{ {

View File

@@ -10,7 +10,7 @@ import io
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from ...common.database import Database from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
@@ -23,13 +23,11 @@ class ImageManager:
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self):
if not self._initialized: if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection() self._ensure_image_collection()
self._ensure_description_collection() self._ensure_description_collection()
self._ensure_image_dir() self._ensure_image_dir()
@@ -42,20 +40,20 @@ class ImageManager:
def _ensure_image_collection(self): def _ensure_image_collection(self):
"""确保images集合存在并创建索引""" """确保images集合存在并创建索引"""
if 'images' not in self.db.list_collection_names(): if 'images' not in db.list_collection_names():
self.db.create_collection('images') db.create_collection('images')
# 创建索引 # 创建索引
self.db.images.create_index([('hash', 1)], unique=True) db.images.create_index([('hash', 1)], unique=True)
self.db.images.create_index([('url', 1)]) db.images.create_index([('url', 1)])
self.db.images.create_index([('path', 1)]) db.images.create_index([('path', 1)])
def _ensure_description_collection(self): def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引""" """确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.list_collection_names(): if 'image_descriptions' not in db.list_collection_names():
self.db.create_collection('image_descriptions') db.create_collection('image_descriptions')
# 创建索引 # 创建索引
self.db.image_descriptions.create_index([('hash', 1)], unique=True) db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.image_descriptions.create_index([('type', 1)]) db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -67,7 +65,7 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result= self.db.image_descriptions.find_one({ result= db.image_descriptions.find_one({
'hash': image_hash, 'hash': image_hash,
'type': description_type 'type': description_type
}) })
@@ -81,7 +79,7 @@ class ImageManager:
description: 描述文本 description: 描述文本
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
self.db.image_descriptions.update_one( db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type}, {'hash': image_hash, 'type': description_type},
{ {
'$set': { '$set': {
@@ -124,7 +122,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重 # 查重
existing = self.db.images.find_one({'hash': image_hash}) existing = db.images.find_one({'hash': image_hash})
if existing: if existing:
return existing['path'] return existing['path']
@@ -145,7 +143,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.images.insert_one(image_doc) db.images.insert_one(image_doc)
return file_path return file_path
@@ -162,7 +160,7 @@ class ImageManager:
""" """
try: try:
# 先查找是否已存在 # 先查找是否已存在
existing = self.db.images.find_one({'url': url}) existing = db.images.find_one({'url': url})
if existing: if existing:
return existing['path'] return existing['path']
@@ -206,7 +204,7 @@ class ImageManager:
Returns: Returns:
bool: 是否存在 bool: 是否存在
""" """
return self.db.images.find_one({'url': url}) is not None return db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在 """检查图像是否已存在
@@ -229,7 +227,7 @@ class ImageManager:
return False return False
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.images.find_one({'hash': image_hash}) is not None return db.images.find_one({'hash': image_hash}) is not None
except Exception as e: except Exception as e:
logger.error(f"检查哈希失败: {str(e)}") logger.error(f"检查哈希失败: {str(e)}")
@@ -273,7 +271,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.images.update_one( db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -335,7 +333,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.images.update_one( db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True

View File

@@ -13,7 +13,7 @@ from loguru import logger
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
from src.common.database import Database # 使用正确的导入语法 from src.common.database import db # 使用正确的导入语法
# 加载.env.dev文件 # 加载.env.dev文件
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev') env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env.dev')
@@ -23,7 +23,6 @@ load_dotenv(env_path)
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
self.G.add_edge(concept1, concept2) self.G.add_edge(concept1, concept2)
@@ -96,7 +95,7 @@ class Memory_graph:
dot_data = { dot_data = {
"concept": node "concept": node
} }
self.db.store_memory_dots.insert_one(dot_data) db.store_memory_dots.insert_one(dot_data)
@property @property
def dots(self): def dots(self):
@@ -106,7 +105,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str): def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录 # 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = '' chat_text = ''
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 closest_record = db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
logger.info( logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@@ -115,7 +114,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # 获取groupid group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_record = list( chat_record = list(
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length)) length))
for record in chat_record: for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@@ -130,50 +129,39 @@ class Memory_graph:
def save_graph_to_db(self): def save_graph_to_db(self):
# 清空现有的图数据 # 清空现有的图数据
self.db.graph_data.delete_many({}) db.graph_data.delete_many({})
# 保存节点 # 保存节点
for node in self.G.nodes(data=True): for node in self.G.nodes(data=True):
node_data = { node_data = {
'concept': node[0], 'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表 'memory_items': node[1].get('memory_items', []) # 默认为空列表
} }
self.db.graph_data.nodes.insert_one(node_data) db.graph_data.nodes.insert_one(node_data)
# 保存边 # 保存边
for edge in self.G.edges(): for edge in self.G.edges():
edge_data = { edge_data = {
'source': edge[0], 'source': edge[0],
'target': edge[1] 'target': edge[1]
} }
self.db.graph_data.edges.insert_one(edge_data) db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self): def load_graph_from_db(self):
# 清空当前图 # 清空当前图
self.G.clear() self.G.clear()
# 加载节点 # 加载节点
nodes = self.db.graph_data.nodes.find() nodes = db.graph_data.nodes.find()
for node in nodes: for node in nodes:
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items) self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边 # 加载边
edges = self.db.graph_data.edges.find() edges = db.graph_data.edges.find()
for edge in edges: for edge in edges:
self.G.add_edge(edge['source'], edge['target']) self.G.add_edge(edge['source'], edge['target'])
def main(): def main():
# 初始化数据库
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
memory_graph = Memory_graph() memory_graph = Memory_graph()
memory_graph.load_graph_from_db() memory_graph.load_graph_from_db()

View File

@@ -10,12 +10,12 @@ import networkx as nx
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from ...common.database import Database # 使用正确的导入语法 from ...common.database import db # 使用正确的导入语法
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import ( from ..chat.utils import (
calculate_information_content, calculate_information_content,
cosine_similarity, cosine_similarity,
get_cloest_chat_from_db, get_closest_chat_from_db,
text_to_vector, text_to_vector,
) )
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@@ -23,7 +23,6 @@ from ..models.utils_model import LLM_request
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
# 避免自连接 # 避免自连接
@@ -191,19 +190,19 @@ class Hippocampus:
# 短期1h 中期4h 长期24h # 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600) random_time = current_timestamp - random.randint(1, 3600)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('mid')): for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600, 3600 * 4) random_time = current_timestamp - random.randint(3600, 3600 * 4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('far')): for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24) random_time = current_timestamp - random.randint(3600 * 4, 3600 * 24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
@@ -349,7 +348,7 @@ class Hippocampus:
def sync_memory_to_db(self): def sync_memory_to_db(self):
"""检查并同步内存中的图结构与数据库""" """检查并同步内存中的图结构与数据库"""
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找 # 转换数据库节点为字典格式,方便查找
@@ -377,7 +376,7 @@ class Hippocampus:
'created_time': created_time, 'created_time': created_time,
'last_modified': last_modified 'last_modified': last_modified
} }
self.memory_graph.db.graph_data.nodes.insert_one(node_data) db.graph_data.nodes.insert_one(node_data)
else: else:
# 获取数据库中节点的特征值 # 获取数据库中节点的特征值
db_node = db_nodes_dict[concept] db_node = db_nodes_dict[concept]
@@ -385,7 +384,7 @@ class Hippocampus:
# 如果特征值不同,则更新节点 # 如果特征值不同,则更新节点
if db_hash != memory_hash: if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': { {'$set': {
'memory_items': memory_items, 'memory_items': memory_items,
@@ -396,7 +395,7 @@ class Hippocampus:
) )
# 处理边的信息 # 处理边的信息
db_edges = list(self.memory_graph.db.graph_data.edges.find()) db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True)) memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典 # 创建边的哈希值字典
@@ -428,11 +427,11 @@ class Hippocampus:
'created_time': created_time, 'created_time': created_time,
'last_modified': last_modified 'last_modified': last_modified
} }
self.memory_graph.db.graph_data.edges.insert_one(edge_data) db.graph_data.edges.insert_one(edge_data)
else: else:
# 检查边的特征值是否变化 # 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash: if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': { {'$set': {
'hash': edge_hash, 'hash': edge_hash,
@@ -451,7 +450,7 @@ class Hippocampus:
self.memory_graph.G.clear() self.memory_graph.G.clear()
# 从数据库加载所有节点 # 从数据库加载所有节点
nodes = list(self.memory_graph.db.graph_data.nodes.find()) nodes = list(db.graph_data.nodes.find())
for node in nodes: for node in nodes:
concept = node['concept'] concept = node['concept']
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
@@ -468,7 +467,7 @@ class Hippocampus:
if 'last_modified' not in node: if 'last_modified' not in node:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': update_data} {'$set': update_data}
) )
@@ -485,7 +484,7 @@ class Hippocampus:
last_modified=last_modified) last_modified=last_modified)
# 从数据库加载所有边 # 从数据库加载所有边
edges = list(self.memory_graph.db.graph_data.edges.find()) edges = list(db.graph_data.edges.find())
for edge in edges: for edge in edges:
source = edge['source'] source = edge['source']
target = edge['target'] target = edge['target']
@@ -501,7 +500,7 @@ class Hippocampus:
if 'last_modified' not in edge: if 'last_modified' not in edge:
update_data['last_modified'] = current_time update_data['last_modified'] = current_time
self.memory_graph.db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': update_data} {'$set': update_data}
) )

View File

@@ -19,7 +19,7 @@ import jieba
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
from src.common.database import Database from src.common.database import db
from src.plugins.memory_system.offline_llm import LLMModel from src.plugins.memory_system.offline_llm import LLMModel
# 获取当前文件的目录 # 获取当前文件的目录
@@ -49,7 +49,7 @@ def calculate_information_content(text):
return entropy return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str): def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数 """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns: Returns:
@@ -91,7 +91,6 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
# 如果边已存在,增加 strength # 如果边已存在,增加 strength
@@ -186,19 +185,19 @@ class Hippocampus:
# 短期1h 中期4h 长期24h # 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4) random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('mid')): for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24) random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('far')): for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
@@ -323,7 +322,7 @@ class Hippocampus:
self.memory_graph.G.clear() self.memory_graph.G.clear()
# 从数据库加载所有节点 # 从数据库加载所有节点
nodes = self.memory_graph.db.graph_data.nodes.find() nodes = db.graph_data.nodes.find()
for node in nodes: for node in nodes:
concept = node['concept'] concept = node['concept']
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
@@ -334,7 +333,7 @@ class Hippocampus:
self.memory_graph.G.add_node(concept, memory_items=memory_items) self.memory_graph.G.add_node(concept, memory_items=memory_items)
# 从数据库加载所有边 # 从数据库加载所有边
edges = self.memory_graph.db.graph_data.edges.find() edges = db.graph_data.edges.find()
for edge in edges: for edge in edges:
source = edge['source'] source = edge['source']
target = edge['target'] target = edge['target']
@@ -371,7 +370,7 @@ class Hippocampus:
使用特征值(哈希值)快速判断是否需要更新 使用特征值(哈希值)快速判断是否需要更新
""" """
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找 # 转换数据库节点为字典格式,方便查找
@@ -394,7 +393,7 @@ class Hippocampus:
'memory_items': memory_items, 'memory_items': memory_items,
'hash': memory_hash 'hash': memory_hash
} }
self.memory_graph.db.graph_data.nodes.insert_one(node_data) db.graph_data.nodes.insert_one(node_data)
else: else:
# 获取数据库中节点的特征值 # 获取数据库中节点的特征值
db_node = db_nodes_dict[concept] db_node = db_nodes_dict[concept]
@@ -403,7 +402,7 @@ class Hippocampus:
# 如果特征值不同,则更新节点 # 如果特征值不同,则更新节点
if db_hash != memory_hash: if db_hash != memory_hash:
# logger.info(f"更新节点内容: {concept}") # logger.info(f"更新节点内容: {concept}")
self.memory_graph.db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': { {'$set': {
'memory_items': memory_items, 'memory_items': memory_items,
@@ -416,10 +415,10 @@ class Hippocampus:
for db_node in db_nodes: for db_node in db_nodes:
if db_node['concept'] not in memory_concepts: if db_node['concept'] not in memory_concepts:
# logger.info(f"删除多余节点: {db_node['concept']}") # logger.info(f"删除多余节点: {db_node['concept']}")
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息 # 处理边的信息
db_edges = list(self.memory_graph.db.graph_data.edges.find()) db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges()) memory_edges = list(self.memory_graph.G.edges())
# 创建边的哈希值字典 # 创建边的哈希值字典
@@ -445,12 +444,12 @@ class Hippocampus:
'num': 1, 'num': 1,
'hash': edge_hash 'hash': edge_hash
} }
self.memory_graph.db.graph_data.edges.insert_one(edge_data) db.graph_data.edges.insert_one(edge_data)
else: else:
# 检查边的特征值是否变化 # 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash: if db_edge_dict[edge_key]['hash'] != edge_hash:
logger.info(f"更新边: {source} - {target}") logger.info(f"更新边: {source} - {target}")
self.memory_graph.db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': {'hash': edge_hash}} {'$set': {'hash': edge_hash}}
) )
@@ -461,7 +460,7 @@ class Hippocampus:
if edge_key not in memory_edge_set: if edge_key not in memory_edge_set:
source, target = edge_key source, target = edge_key
logger.info(f"删除多余边: {source} - {target}") logger.info(f"删除多余边: {source} - {target}")
self.memory_graph.db.graph_data.edges.delete_one({ db.graph_data.edges.delete_one({
'source': source, 'source': source,
'target': target 'target': target
}) })
@@ -487,9 +486,9 @@ class Hippocampus:
topic: 要删除的节点概念 topic: 要删除的节点概念
""" """
# 删除节点 # 删除节点
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic}) db.graph_data.nodes.delete_one({'concept': topic})
# 删除所有涉及该节点的边 # 删除所有涉及该节点的边
self.memory_graph.db.graph_data.edges.delete_many({ db.graph_data.edges.delete_many({
'$or': [ '$or': [
{'source': topic}, {'source': topic},
{'target': topic} {'target': topic}
@@ -902,17 +901,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
plt.show() plt.show()
async def main(): async def main():
# 初始化数据库
logger.info("正在初始化数据库连接...")
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
start_time = time.time() start_time = time.time()
test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} test_pare = {'do_build_memory':False,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}

View File

@@ -38,7 +38,7 @@ import jieba
# from chat.config import global_config # from chat.config import global_config
sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径 sys.path.append("C:/GitHub/MaiMBot") # 添加项目根目录到 Python 路径
from src.common.database import Database from src.common.database import db
from src.plugins.memory_system.offline_llm import LLMModel from src.plugins.memory_system.offline_llm import LLMModel
# 获取当前文件的目录 # 获取当前文件的目录
@@ -56,45 +56,6 @@ else:
logger.warning(f"未找到环境变量文件: {env_path}") logger.warning(f"未找到环境变量文件: {env_path}")
logger.info("将使用默认配置") logger.info("将使用默认配置")
class Database:
_instance = None
db = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
if not Database.db:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
@classmethod
def initialize(cls, host, port, db_name, username=None, password=None, auth_source="admin"):
try:
if username and password:
uri = f"mongodb://{username}:{password}@{host}:{port}/{db_name}?authSource={auth_source}"
else:
uri = f"mongodb://{host}:{port}"
client = pymongo.MongoClient(uri)
cls.db = client[db_name]
# 测试连接
client.server_info()
logger.success("MongoDB连接成功!")
except Exception as e:
logger.error(f"初始化MongoDB失败: {str(e)}")
raise
def calculate_information_content(text): def calculate_information_content(text):
"""计算文本的信息量(熵)""" """计算文本的信息量(熵)"""
@@ -108,7 +69,7 @@ def calculate_information_content(text):
return entropy return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str): def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数 """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
Returns: Returns:
@@ -163,7 +124,7 @@ class Memory_cortex:
default_time = datetime.datetime.now().timestamp() default_time = datetime.datetime.now().timestamp()
# 从数据库加载所有节点 # 从数据库加载所有节点
nodes = self.memory_graph.db.graph_data.nodes.find() nodes = db.graph_data.nodes.find()
for node in nodes: for node in nodes:
concept = node['concept'] concept = node['concept']
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
@@ -180,7 +141,7 @@ class Memory_cortex:
created_time = default_time created_time = default_time
last_modified = default_time last_modified = default_time
# 更新数据库中的节点 # 更新数据库中的节点
self.memory_graph.db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': { {'$set': {
'created_time': created_time, 'created_time': created_time,
@@ -196,7 +157,7 @@ class Memory_cortex:
last_modified=last_modified) last_modified=last_modified)
# 从数据库加载所有边 # 从数据库加载所有边
edges = self.memory_graph.db.graph_data.edges.find() edges = db.graph_data.edges.find()
for edge in edges: for edge in edges:
source = edge['source'] source = edge['source']
target = edge['target'] target = edge['target']
@@ -212,7 +173,7 @@ class Memory_cortex:
created_time = default_time created_time = default_time
last_modified = default_time last_modified = default_time
# 更新数据库中的边 # 更新数据库中的边
self.memory_graph.db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': { {'$set': {
'created_time': created_time, 'created_time': created_time,
@@ -256,7 +217,7 @@ class Memory_cortex:
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 获取数据库中所有节点和内存中所有节点 # 获取数据库中所有节点和内存中所有节点
db_nodes = list(self.memory_graph.db.graph_data.nodes.find()) db_nodes = list(db.graph_data.nodes.find())
memory_nodes = list(self.memory_graph.G.nodes(data=True)) memory_nodes = list(self.memory_graph.G.nodes(data=True))
# 转换数据库节点为字典格式,方便查找 # 转换数据库节点为字典格式,方便查找
@@ -280,7 +241,7 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time), 'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time) 'last_modified': data.get('last_modified', current_time)
} }
self.memory_graph.db.graph_data.nodes.insert_one(node_data) db.graph_data.nodes.insert_one(node_data)
else: else:
# 获取数据库中节点的特征值 # 获取数据库中节点的特征值
db_node = db_nodes_dict[concept] db_node = db_nodes_dict[concept]
@@ -288,7 +249,7 @@ class Memory_cortex:
# 如果特征值不同,则更新节点 # 如果特征值不同,则更新节点
if db_hash != memory_hash: if db_hash != memory_hash:
self.memory_graph.db.graph_data.nodes.update_one( db.graph_data.nodes.update_one(
{'concept': concept}, {'concept': concept},
{'$set': { {'$set': {
'memory_items': memory_items, 'memory_items': memory_items,
@@ -301,10 +262,10 @@ class Memory_cortex:
memory_concepts = set(node[0] for node in memory_nodes) memory_concepts = set(node[0] for node in memory_nodes)
for db_node in db_nodes: for db_node in db_nodes:
if db_node['concept'] not in memory_concepts: if db_node['concept'] not in memory_concepts:
self.memory_graph.db.graph_data.nodes.delete_one({'concept': db_node['concept']}) db.graph_data.nodes.delete_one({'concept': db_node['concept']})
# 处理边的信息 # 处理边的信息
db_edges = list(self.memory_graph.db.graph_data.edges.find()) db_edges = list(db.graph_data.edges.find())
memory_edges = list(self.memory_graph.G.edges(data=True)) memory_edges = list(self.memory_graph.G.edges(data=True))
# 创建边的哈希值字典 # 创建边的哈希值字典
@@ -332,11 +293,11 @@ class Memory_cortex:
'created_time': data.get('created_time', current_time), 'created_time': data.get('created_time', current_time),
'last_modified': data.get('last_modified', current_time) 'last_modified': data.get('last_modified', current_time)
} }
self.memory_graph.db.graph_data.edges.insert_one(edge_data) db.graph_data.edges.insert_one(edge_data)
else: else:
# 检查边的特征值是否变化 # 检查边的特征值是否变化
if db_edge_dict[edge_key]['hash'] != edge_hash: if db_edge_dict[edge_key]['hash'] != edge_hash:
self.memory_graph.db.graph_data.edges.update_one( db.graph_data.edges.update_one(
{'source': source, 'target': target}, {'source': source, 'target': target},
{'$set': { {'$set': {
'hash': edge_hash, 'hash': edge_hash,
@@ -350,7 +311,7 @@ class Memory_cortex:
for edge_key in db_edge_dict: for edge_key in db_edge_dict:
if edge_key not in memory_edge_set: if edge_key not in memory_edge_set:
source, target = edge_key source, target = edge_key
self.memory_graph.db.graph_data.edges.delete_one({ db.graph_data.edges.delete_one({
'source': source, 'source': source,
'target': target 'target': target
}) })
@@ -365,9 +326,9 @@ class Memory_cortex:
topic: 要删除的节点概念 topic: 要删除的节点概念
""" """
# 删除节点 # 删除节点
self.memory_graph.db.graph_data.nodes.delete_one({'concept': topic}) db.graph_data.nodes.delete_one({'concept': topic})
# 删除所有涉及该节点的边 # 删除所有涉及该节点的边
self.memory_graph.db.graph_data.edges.delete_many({ db.graph_data.edges.delete_many({
'$or': [ '$or': [
{'source': topic}, {'source': topic},
{'target': topic} {'target': topic}
@@ -377,7 +338,6 @@ class Memory_cortex:
class Memory_graph: class Memory_graph:
def __init__(self): def __init__(self):
self.G = nx.Graph() # 使用 networkx 的图结构 self.G = nx.Graph() # 使用 networkx 的图结构
self.db = Database.get_instance()
def connect_dot(self, concept1, concept2): def connect_dot(self, concept1, concept2):
# 避免自连接 # 避免自连接
@@ -492,19 +452,19 @@ class Hippocampus:
# 短期1h 中期4h 长期24h # 短期1h 中期4h 长期24h
for _ in range(time_frequency.get('near')): for _ in range(time_frequency.get('near')):
random_time = current_timestamp - random.randint(1, 3600*4) random_time = current_timestamp - random.randint(1, 3600*4)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('mid')): for _ in range(time_frequency.get('mid')):
random_time = current_timestamp - random.randint(3600*4, 3600*24) random_time = current_timestamp - random.randint(3600*4, 3600*24)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
for _ in range(time_frequency.get('far')): for _ in range(time_frequency.get('far')):
random_time = current_timestamp - random.randint(3600*24, 3600*24*7) random_time = current_timestamp - random.randint(3600*24, 3600*24*7)
messages = get_cloest_chat_from_db(db=self.memory_graph.db, length=chat_size, timestamp=random_time) messages = get_closest_chat_from_db(length=chat_size, timestamp=random_time)
if messages: if messages:
chat_samples.append(messages) chat_samples.append(messages)
@@ -1134,7 +1094,6 @@ def visualize_graph_lite(memory_graph: Memory_graph, color_by_memory: bool = Fal
async def main(): async def main():
# 初始化数据库 # 初始化数据库
logger.info("正在初始化数据库连接...") logger.info("正在初始化数据库连接...")
db = Database.get_instance()
start_time = time.time() start_time = time.time()
test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False} test_pare = {'do_build_memory':True,'do_forget_topic':False,'do_visualize_graph':True,'do_query':False,'do_merge_memory':False}

View File

@@ -10,7 +10,7 @@ from nonebot import get_driver
import base64 import base64
from PIL import Image from PIL import Image
import io import io
from ...common.database import Database from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
driver = get_driver() driver = get_driver()
@@ -34,17 +34,16 @@ class LLM_request:
self.pri_out = model.get("pri_out", 0) self.pri_out = model.get("pri_out", 0)
# 获取数据库实例 # 获取数据库实例
self.db = Database.get_instance()
self._init_database() self._init_database()
def _init_database(self): def _init_database(self):
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 创建llm_usage集合的索引
self.db.llm_usage.create_index([("timestamp", 1)]) db.llm_usage.create_index([("timestamp", 1)])
self.db.llm_usage.create_index([("model_name", 1)]) db.llm_usage.create_index([("model_name", 1)])
self.db.llm_usage.create_index([("user_id", 1)]) db.llm_usage.create_index([("user_id", 1)])
self.db.llm_usage.create_index([("request_type", 1)]) db.llm_usage.create_index([("request_type", 1)])
except Exception: except Exception:
logger.error("创建数据库索引失败") logger.error("创建数据库索引失败")
@@ -73,7 +72,7 @@ class LLM_request:
"status": "success", "status": "success",
"timestamp": datetime.now() "timestamp": datetime.now()
} }
self.db.llm_usage.insert_one(usage_data) db.llm_usage.insert_one(usage_data)
logger.info( logger.info(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "

View File

@@ -8,7 +8,7 @@ from nonebot import get_driver
from src.plugins.chat.config import global_config from src.plugins.chat.config import global_config
from ...common.database import Database # 使用正确的导入语法 from ...common.database import db # 使用正确的导入语法
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
@@ -19,7 +19,6 @@ class ScheduleGenerator:
# 根据global_config.llm_normal这一字典配置指定模型 # 根据global_config.llm_normal这一字典配置指定模型
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9) # self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9) self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
self.db = Database.get_instance()
self.today_schedule_text = "" self.today_schedule_text = ""
self.today_schedule = {} self.today_schedule = {}
self.tomorrow_schedule_text = "" self.tomorrow_schedule_text = ""
@@ -46,7 +45,7 @@ class ScheduleGenerator:
schedule_text = str schedule_text = str
existing_schedule = self.db.schedule.find_one({"date": date_str}) existing_schedule = db.schedule.find_one({"date": date_str})
if existing_schedule: if existing_schedule:
logger.debug(f"{date_str}的日程已存在:") logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"] schedule_text = existing_schedule["schedule"]
@@ -63,7 +62,7 @@ class ScheduleGenerator:
try: try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt) schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e: except Exception as e:
logger.error(f"生成日程失败: {str(e)}") logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了" schedule_text = "生成日程时出错了"
@@ -143,7 +142,7 @@ class ScheduleGenerator:
"""打印完整的日程安排""" """打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text): if not self._parse_schedule(self.today_schedule_text):
logger.warning("今日日程有误,将在下次运行时重新生成") logger.warning("今日日程有误,将在下次运行时重新生成")
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
logger.info("=== 今日日程安排 ===") logger.info("=== 今日日程安排 ===")
for time_str, activity in self.today_schedule.items(): for time_str, activity in self.today_schedule.items():

View File

@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
from typing import Any, Dict from typing import Any, Dict
from loguru import logger from loguru import logger
from ...common.database import Database from ...common.database import db
class LLMStatistics: class LLMStatistics:
@@ -15,7 +15,6 @@ class LLMStatistics:
Args: Args:
output_file: 统计结果输出文件路径 output_file: 统计结果输出文件路径
""" """
self.db = Database.get_instance()
self.output_file = output_file self.output_file = output_file
self.running = False self.running = False
self.stats_thread = None self.stats_thread = None
@@ -53,7 +52,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float) "costs_by_model": defaultdict(float)
} }
cursor = self.db.llm_usage.find({ cursor = db.llm_usage.find({
"timestamp": {"$gte": start_time} "timestamp": {"$gte": start_time}
}) })

View File

@@ -14,7 +14,7 @@ root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path) sys.path.append(root_path)
# 现在可以导入src模块 # 现在可以导入src模块
from src.common.database import Database from src.common.database import db
# 加载根目录下的 .env.prod 文件 # 加载根目录下的 .env.prod 文件
env_path = os.path.join(root_path, ".env.prod") env_path = os.path.join(root_path, ".env.prod")
@@ -24,18 +24,6 @@ load_dotenv(env_path)
class KnowledgeLibrary: class KnowledgeLibrary:
def __init__(self): def __init__(self):
# 初始化数据库连接
if Database._instance is None:
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
self.db = Database.get_instance()
self.raw_info_dir = "data/raw_info" self.raw_info_dir = "data/raw_info"
self._ensure_dirs() self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY") self.api_key = os.getenv("SILICONFLOW_KEY")
@@ -176,7 +164,7 @@ class KnowledgeLibrary:
try: try:
current_hash = self.calculate_file_hash(file_path) current_hash = self.calculate_file_hash(file_path)
processed_record = self.db.processed_files.find_one({"file_path": file_path}) processed_record = db.processed_files.find_one({"file_path": file_path})
if processed_record: if processed_record:
if processed_record.get("hash") == current_hash: if processed_record.get("hash") == current_hash:
@@ -197,14 +185,14 @@ class KnowledgeLibrary:
"split_length": knowledge_length, "split_length": knowledge_length,
"created_at": datetime.now() "created_at": datetime.now()
} }
self.db.knowledges.insert_one(knowledge) db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1 result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else [] split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by: if knowledge_length not in split_by:
split_by.append(knowledge_length) split_by.append(knowledge_length)
self.db.knowledges.processed_files.update_one( db.knowledges.processed_files.update_one(
{"file_path": file_path}, {"file_path": file_path},
{ {
"$set": { "$set": {
@@ -322,7 +310,7 @@ class KnowledgeLibrary:
{"$project": {"content": 1, "similarity": 1, "file_path": 1}} {"$project": {"content": 1, "similarity": 1, "file_path": 1}}
] ]
results = list(self.db.knowledges.aggregate(pipeline)) results = list(db.knowledges.aggregate(pipeline))
return results return results
# 创建单例实例 # 创建单例实例
@@ -346,7 +334,7 @@ if __name__ == "__main__":
elif choice == '2': elif choice == '2':
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower() confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == 'y': if confirm == 'y':
knowledge_library.db.knowledges.delete_many({}) db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]") console.print("[green]已清空所有知识![/green]")
continue continue
elif choice == '1': elif choice == '1':