重构数据库访问,替换为统一的数据库实例引用

This commit is contained in:
晴猫
2025-03-12 22:27:59 +09:00
parent 49082267bb
commit 8be087dcad
19 changed files with 138 additions and 284 deletions

View File

@@ -7,7 +7,6 @@ from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State
from ...common.database import Database
from ..moods.moods import MoodManager # 导入情绪管理器
from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import GroupInfo, UserInfo
@@ -83,7 +83,6 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
@@ -111,11 +110,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +167,7 @@ class ChatManager:
return stream
# 检查数据库中是否存在
data = self.db.chat_streams.find_one({"stream_id": stream_id})
data = db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -204,7 +203,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.chat_streams.update_one(
db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +215,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.chat_streams.find({})
all_streams = db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream

View File

@@ -12,7 +12,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64
@@ -30,12 +30,10 @@ class EmojiManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60,
@@ -50,7 +48,6 @@ class EmojiManager:
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self.db = Database.get_instance()
self._ensure_emoji_collection()
self._ensure_emoji_dir()
self._initialized = True
@@ -78,16 +75,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if 'emoji' not in self.db.list_collection_names():
self.db.create_collection('emoji')
self.db.emoji.create_index([('embedding', '2dsphere')])
self.db.emoji.create_index([('filename', 1)], unique=True)
if 'emoji' not in db.list_collection_names():
db.create_collection('emoji')
db.emoji.create_index([('embedding', '2dsphere')])
db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.emoji.update_one(
db.emoji.update_one(
{'_id': emoji_id},
{'$inc': {'usage_count': 1}}
)
@@ -121,7 +118,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
all_emojis = list(db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -159,7 +156,7 @@ class EmojiManager:
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
self.db.emoji.update_one(
db.emoji.update_one(
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
@@ -241,14 +238,14 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过
existing_emoji = self.db['emoji'].find_one({'filename': filename})
existing_emoji = db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = image_manager.db.images.find_one({'hash': image_hash})
existing_image = db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
@@ -258,7 +255,7 @@ class EmojiManager:
'description': description,
'timestamp': int(time.time())
}
image_manager.db.images.update_one(
db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -307,7 +304,7 @@ class EmojiManager:
}
# 保存到emoji数据库
self.db['emoji'].insert_one(emoji_record)
db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
@@ -320,7 +317,7 @@ class EmojiManager:
'description': description,
'timestamp': int(time.time())
}
image_manager.db.images.update_one(
db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -348,7 +345,7 @@ class EmojiManager:
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(self.db.emoji.find())
all_emojis = list(db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -356,13 +353,13 @@ class EmojiManager:
try:
if 'path' not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({'_id': emoji['_id']})
db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({'_id': emoji['_id']})
db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1
continue
@@ -370,7 +367,7 @@ class EmojiManager:
if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.emoji.delete_one({'_id': emoji['_id']})
result = db.emoji.delete_one({'_id': emoji['_id']})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -381,7 +378,7 @@ class EmojiManager:
continue
# 验证清理结果
remaining_count = self.db.emoji.count_documents({})
remaining_count = db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..models.utils_model import LLM_request
from .config import global_config
from .message import MessageRecv, MessageThinking, Message
@@ -34,7 +34,6 @@ class ResponseGenerator:
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(
@@ -154,7 +153,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.reasoning_logs.insert_one(
db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
@@ -211,7 +210,6 @@ class ResponseGenerator:
class InitiativeMessageGenerate:
def __init__(self):
self.db = Database.get_instance()
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(

View File

@@ -3,7 +3,7 @@ import time
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..memory_system.memory import hippocampus, memory_graph
from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
@@ -16,7 +16,6 @@ class PromptBuilder:
def __init__(self):
self.prompt_built = ''
self.activate_messages = ''
self.db = Database.get_instance()
@@ -76,7 +75,7 @@ class PromptBuilder:
chat_in_group=True
chat_talking_prompt = ''
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
@@ -199,7 +198,7 @@ class PromptBuilder:
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
@@ -311,7 +310,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.knowledges.aggregate(pipeline))
results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import UserInfo
from .chat_stream import ChatStream
@@ -167,14 +167,12 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.relationships.find({})
# 依次加载每条记录
@@ -205,7 +203,6 @@ class RelationshipManager:
age = relationship.age
saved = relationship.saved
db = Database.get_instance()
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {

View File

@@ -1,15 +1,12 @@
from typing import Optional, Union
from ...common.database import Database
from ...common.database import db
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
@@ -23,7 +20,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.messages.insert_one(message_data)
db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -16,6 +16,7 @@ from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
from ...common.database import db
driver = get_driver()
config = driver.config
@@ -76,11 +77,10 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_cloest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
@@ -115,11 +115,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return []
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
db: Database实例
group_id: 群组ID
limit: 获取消息数量默认12条
@@ -161,7 +160,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.messages.find(
{"chat_id": chat_stream_id},
{

View File

@@ -10,7 +10,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
@@ -23,13 +23,11 @@ class ImageManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
@@ -42,20 +40,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.list_collection_names():
self.db.create_collection('images')
if 'images' not in db.list_collection_names():
db.create_collection('images')
# 创建索引
self.db.images.create_index([('hash', 1)], unique=True)
self.db.images.create_index([('url', 1)])
self.db.images.create_index([('path', 1)])
db.images.create_index([('hash', 1)], unique=True)
db.images.create_index([('url', 1)])
db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.list_collection_names():
self.db.create_collection('image_descriptions')
if 'image_descriptions' not in db.list_collection_names():
db.create_collection('image_descriptions')
# 创建索引
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.image_descriptions.create_index([('type', 1)])
db.image_descriptions.create_index([('hash', 1)], unique=True)
db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -67,7 +65,7 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.image_descriptions.find_one({
result= db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
@@ -81,7 +79,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.image_descriptions.update_one(
db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
@@ -124,7 +122,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重
existing = self.db.images.find_one({'hash': image_hash})
existing = db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
@@ -145,7 +143,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.images.insert_one(image_doc)
db.images.insert_one(image_doc)
return file_path
@@ -162,7 +160,7 @@ class ImageManager:
"""
try:
# 先查找是否已存在
existing = self.db.images.find_one({'url': url})
existing = db.images.find_one({'url': url})
if existing:
return existing['path']
@@ -206,7 +204,7 @@ class ImageManager:
Returns:
bool: 是否存在
"""
return self.db.images.find_one({'url': url}) is not None
return db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -229,7 +227,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.images.find_one({'hash': image_hash}) is not None
return db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -273,7 +271,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.images.update_one(
db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -335,7 +333,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
self.db.images.update_one(
db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True