Merge pull request #237 from Rikki-Zero/debug
refactor: 修复database单例多次初始化的问题,改变instance默认返回实例的类型,缩短db相关函数调用时的object名
This commit is contained in:
15
bot.py
15
bot.py
@@ -12,6 +12,8 @@ from loguru import logger
|
|||||||
from nonebot.adapters.onebot.v11 import Adapter
|
from nonebot.adapters.onebot.v11 import Adapter
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
|
from src.common.database import Database
|
||||||
|
|
||||||
# 获取没有加载env时的环境变量
|
# 获取没有加载env时的环境变量
|
||||||
env_mask = {key: os.getenv(key) for key in os.environ}
|
env_mask = {key: os.getenv(key) for key in os.environ}
|
||||||
|
|
||||||
@@ -96,6 +98,17 @@ def load_env():
|
|||||||
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
||||||
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
|
||||||
|
|
||||||
|
def init_database():
|
||||||
|
Database.initialize(
|
||||||
|
uri=os.getenv("MONGODB_URI"),
|
||||||
|
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||||
|
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||||
|
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||||
|
username=os.getenv("MONGODB_USERNAME"),
|
||||||
|
password=os.getenv("MONGODB_PASSWORD"),
|
||||||
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_logger():
|
def load_logger():
|
||||||
logger.remove() # 移除默认配置
|
logger.remove() # 移除默认配置
|
||||||
@@ -198,6 +211,7 @@ def raw_main():
|
|||||||
init_config()
|
init_config()
|
||||||
init_env()
|
init_env()
|
||||||
load_env()
|
load_env()
|
||||||
|
init_database() # 加载完成环境后初始化database
|
||||||
load_logger()
|
load_logger()
|
||||||
|
|
||||||
env_config = {key: os.getenv(key) for key in os.environ}
|
env_config = {key: os.getenv(key) for key in os.environ}
|
||||||
@@ -223,7 +237,6 @@ def raw_main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_main()
|
raw_main()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
from pymongo.database import Database as MongoDatabase
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
_instance: Optional["Database"] = None
|
_instance: Optional["Database"] = None
|
||||||
@@ -25,7 +26,7 @@ class Database:
|
|||||||
else:
|
else:
|
||||||
# 否则使用无认证连接
|
# 否则使用无认证连接
|
||||||
self.client = MongoClient(host, port)
|
self.client = MongoClient(host, port)
|
||||||
self.db = self.client[db_name]
|
self.db: MongoDatabase = self.client[db_name]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(
|
def initialize(
|
||||||
@@ -37,15 +38,36 @@ class Database:
|
|||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
auth_source: Optional[str] = None,
|
auth_source: Optional[str] = None,
|
||||||
uri: Optional[str] = None,
|
uri: Optional[str] = None,
|
||||||
) -> "Database":
|
) -> MongoDatabase:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = cls(
|
cls._instance = cls(
|
||||||
host, port, db_name, username, password, auth_source, uri
|
host, port, db_name, username, password, auth_source, uri
|
||||||
)
|
)
|
||||||
return cls._instance
|
return cls._instance.db
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls) -> "Database":
|
def get_instance(cls) -> MongoDatabase:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
raise RuntimeError("Database not initialized")
|
raise RuntimeError("Database not initialized")
|
||||||
return cls._instance
|
return cls._instance.db
|
||||||
|
|
||||||
|
|
||||||
|
#测试用
|
||||||
|
|
||||||
|
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
||||||
|
# 先随机获取一条消息
|
||||||
|
random_message = list(self.db.messages.aggregate([
|
||||||
|
{"$match": {"group_id": group_id}},
|
||||||
|
{"$sample": {"size": 1}}
|
||||||
|
]))[0]
|
||||||
|
|
||||||
|
# 获取该消息之后的消息
|
||||||
|
subsequent_messages = list(self.db.messages.find({
|
||||||
|
"group_id": group_id,
|
||||||
|
"time": {"$gt": random_message["time"]}
|
||||||
|
}).sort("time", 1).limit(limit))
|
||||||
|
|
||||||
|
# 将随机消息和后续消息合并
|
||||||
|
messages = [random_message] + subsequent_messages
|
||||||
|
|
||||||
|
return messages
|
||||||
@@ -46,7 +46,7 @@ class ReasoningGUI:
|
|||||||
|
|
||||||
# 初始化数据库连接
|
# 初始化数据库连接
|
||||||
try:
|
try:
|
||||||
self.db = Database.get_instance().db
|
self.db = Database.get_instance()
|
||||||
logger.success("数据库连接成功")
|
logger.success("数据库连接成功")
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
logger.warning("数据库未初始化,正在尝试初始化...")
|
logger.warning("数据库未初始化,正在尝试初始化...")
|
||||||
@@ -60,7 +60,7 @@ class ReasoningGUI:
|
|||||||
password=os.getenv("MONGODB_PASSWORD"),
|
password=os.getenv("MONGODB_PASSWORD"),
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||||
)
|
)
|
||||||
self.db = Database.get_instance().db
|
self.db = Database.get_instance()
|
||||||
logger.success("数据库初始化成功")
|
logger.success("数据库初始化成功")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("数据库初始化失败")
|
logger.exception("数据库初始化失败")
|
||||||
|
|||||||
@@ -32,18 +32,6 @@ _message_manager_started = False
|
|||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
|
||||||
uri=os.getenv("MONGODB_URI"),
|
|
||||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
|
||||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
|
||||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
|
||||||
username=os.getenv("MONGODB_USERNAME"),
|
|
||||||
password=os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
|
||||||
)
|
|
||||||
logger.success("初始化数据库成功")
|
|
||||||
|
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
emoji_manager.initialize()
|
emoji_manager.initialize()
|
||||||
|
|
||||||
|
|||||||
@@ -111,11 +111,11 @@ class ChatManager:
|
|||||||
|
|
||||||
def _ensure_collection(self):
|
def _ensure_collection(self):
|
||||||
"""确保数据库集合存在并创建索引"""
|
"""确保数据库集合存在并创建索引"""
|
||||||
if "chat_streams" not in self.db.db.list_collection_names():
|
if "chat_streams" not in self.db.list_collection_names():
|
||||||
self.db.db.create_collection("chat_streams")
|
self.db.create_collection("chat_streams")
|
||||||
# 创建索引
|
# 创建索引
|
||||||
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||||
self.db.db.chat_streams.create_index(
|
self.db.chat_streams.create_index(
|
||||||
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
|
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ class ChatManager:
|
|||||||
return stream
|
return stream
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
|
data = self.db.chat_streams.find_one({"stream_id": stream_id})
|
||||||
if data:
|
if data:
|
||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
@@ -204,7 +204,7 @@ class ChatManager:
|
|||||||
async def _save_stream(self, stream: ChatStream):
|
async def _save_stream(self, stream: ChatStream):
|
||||||
"""保存聊天流到数据库"""
|
"""保存聊天流到数据库"""
|
||||||
if not stream.saved:
|
if not stream.saved:
|
||||||
self.db.db.chat_streams.update_one(
|
self.db.chat_streams.update_one(
|
||||||
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
|
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
|
||||||
)
|
)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
@@ -216,7 +216,7 @@ class ChatManager:
|
|||||||
|
|
||||||
async def load_all_streams(self):
|
async def load_all_streams(self):
|
||||||
"""从数据库加载所有聊天流"""
|
"""从数据库加载所有聊天流"""
|
||||||
all_streams = self.db.db.chat_streams.find({})
|
all_streams = self.db.chat_streams.find({})
|
||||||
for data in all_streams:
|
for data in all_streams:
|
||||||
stream = ChatStream.from_dict(data)
|
stream = ChatStream.from_dict(data)
|
||||||
self.streams[stream.stream_id] = stream
|
self.streams[stream.stream_id] = stream
|
||||||
|
|||||||
@@ -76,16 +76,16 @@ class EmojiManager:
|
|||||||
|
|
||||||
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
|
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
|
||||||
"""
|
"""
|
||||||
if 'emoji' not in self.db.db.list_collection_names():
|
if 'emoji' not in self.db.list_collection_names():
|
||||||
self.db.db.create_collection('emoji')
|
self.db.create_collection('emoji')
|
||||||
self.db.db.emoji.create_index([('embedding', '2dsphere')])
|
self.db.emoji.create_index([('embedding', '2dsphere')])
|
||||||
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
self.db.emoji.create_index([('filename', 1)], unique=True)
|
||||||
|
|
||||||
def record_usage(self, emoji_id: str):
|
def record_usage(self, emoji_id: str):
|
||||||
"""记录表情使用次数"""
|
"""记录表情使用次数"""
|
||||||
try:
|
try:
|
||||||
self._ensure_db()
|
self._ensure_db()
|
||||||
self.db.db.emoji.update_one(
|
self.db.emoji.update_one(
|
||||||
{'_id': emoji_id},
|
{'_id': emoji_id},
|
||||||
{'$inc': {'usage_count': 1}}
|
{'$inc': {'usage_count': 1}}
|
||||||
)
|
)
|
||||||
@@ -119,7 +119,7 @@ class EmojiManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取所有表情包
|
# 获取所有表情包
|
||||||
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
|
all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
|
||||||
|
|
||||||
if not all_emojis:
|
if not all_emojis:
|
||||||
logger.warning("数据库中没有任何表情包")
|
logger.warning("数据库中没有任何表情包")
|
||||||
@@ -157,7 +157,7 @@ class EmojiManager:
|
|||||||
|
|
||||||
if selected_emoji and 'path' in selected_emoji:
|
if selected_emoji and 'path' in selected_emoji:
|
||||||
# 更新使用次数
|
# 更新使用次数
|
||||||
self.db.db.emoji.update_one(
|
self.db.emoji.update_one(
|
||||||
{'_id': selected_emoji['_id']},
|
{'_id': selected_emoji['_id']},
|
||||||
{'$inc': {'usage_count': 1}}
|
{'$inc': {'usage_count': 1}}
|
||||||
)
|
)
|
||||||
@@ -239,7 +239,7 @@ class EmojiManager:
|
|||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 检查是否已经注册过
|
# 检查是否已经注册过
|
||||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
existing_emoji = self.db['emoji'].find_one({'filename': filename})
|
||||||
description = None
|
description = None
|
||||||
|
|
||||||
if existing_emoji:
|
if existing_emoji:
|
||||||
@@ -305,7 +305,7 @@ class EmojiManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 保存到emoji数据库
|
# 保存到emoji数据库
|
||||||
self.db.db['emoji'].insert_one(emoji_record)
|
self.db['emoji'].insert_one(emoji_record)
|
||||||
logger.success(f"注册新表情包: {filename}")
|
logger.success(f"注册新表情包: {filename}")
|
||||||
logger.info(f"描述: {description}")
|
logger.info(f"描述: {description}")
|
||||||
|
|
||||||
@@ -346,7 +346,7 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
self._ensure_db()
|
self._ensure_db()
|
||||||
# 获取所有表情包记录
|
# 获取所有表情包记录
|
||||||
all_emojis = list(self.db.db.emoji.find())
|
all_emojis = list(self.db.emoji.find())
|
||||||
removed_count = 0
|
removed_count = 0
|
||||||
total_count = len(all_emojis)
|
total_count = len(all_emojis)
|
||||||
|
|
||||||
@@ -354,13 +354,13 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
if 'path' not in emoji:
|
if 'path' not in emoji:
|
||||||
logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}")
|
logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}")
|
||||||
self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||||
removed_count += 1
|
removed_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if 'embedding' not in emoji:
|
if 'embedding' not in emoji:
|
||||||
logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}")
|
logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}")
|
||||||
self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||||
removed_count += 1
|
removed_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -368,7 +368,7 @@ class EmojiManager:
|
|||||||
if not os.path.exists(emoji['path']):
|
if not os.path.exists(emoji['path']):
|
||||||
logger.warning(f"表情包文件已被删除: {emoji['path']}")
|
logger.warning(f"表情包文件已被删除: {emoji['path']}")
|
||||||
# 从数据库中删除记录
|
# 从数据库中删除记录
|
||||||
result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
|
result = self.db.emoji.delete_one({'_id': emoji['_id']})
|
||||||
if result.deleted_count > 0:
|
if result.deleted_count > 0:
|
||||||
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
|
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
|
||||||
removed_count += 1
|
removed_count += 1
|
||||||
@@ -379,7 +379,7 @@ class EmojiManager:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 验证清理结果
|
# 验证清理结果
|
||||||
remaining_count = self.db.db.emoji.count_documents({})
|
remaining_count = self.db.emoji.count_documents({})
|
||||||
if removed_count > 0:
|
if removed_count > 0:
|
||||||
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
|
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
|
||||||
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
|
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ class ResponseGenerator:
|
|||||||
reasoning_content: str,
|
reasoning_content: str,
|
||||||
):
|
):
|
||||||
"""保存对话记录到数据库"""
|
"""保存对话记录到数据库"""
|
||||||
self.db.db.reasoning_logs.insert_one(
|
self.db.reasoning_logs.insert_one(
|
||||||
{
|
{
|
||||||
"time": time.time(),
|
"time": time.time(),
|
||||||
"chat_id": message.chat_stream.stream_id,
|
"chat_id": message.chat_stream.stream_id,
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ class PromptBuilder:
|
|||||||
{"$project": {"content": 1, "similarity": 1}}
|
{"$project": {"content": 1, "similarity": 1}}
|
||||||
]
|
]
|
||||||
|
|
||||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
results = list(self.db.knowledges.aggregate(pipeline))
|
||||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ class RelationshipManager:
|
|||||||
async def load_all_relationships(self):
|
async def load_all_relationships(self):
|
||||||
"""加载所有关系对象"""
|
"""加载所有关系对象"""
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
all_relationships = db.db.relationships.find({})
|
all_relationships = db.relationships.find({})
|
||||||
for data in all_relationships:
|
for data in all_relationships:
|
||||||
await self.load_relationship(data)
|
await self.load_relationship(data)
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ class RelationshipManager:
|
|||||||
"""每5分钟自动保存一次关系数据"""
|
"""每5分钟自动保存一次关系数据"""
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
# 获取所有关系记录
|
# 获取所有关系记录
|
||||||
all_relationships = db.db.relationships.find({})
|
all_relationships = db.relationships.find({})
|
||||||
# 依次加载每条记录
|
# 依次加载每条记录
|
||||||
for data in all_relationships:
|
for data in all_relationships:
|
||||||
await self.load_relationship(data)
|
await self.load_relationship(data)
|
||||||
@@ -206,7 +206,7 @@ class RelationshipManager:
|
|||||||
saved = relationship.saved
|
saved = relationship.saved
|
||||||
|
|
||||||
db = Database.get_instance()
|
db = Database.get_instance()
|
||||||
db.db.relationships.update_one(
|
db.relationships.update_one(
|
||||||
{'user_id': user_id, 'platform': platform},
|
{'user_id': user_id, 'platform': platform},
|
||||||
{'$set': {
|
{'$set': {
|
||||||
'platform': platform,
|
'platform': platform,
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ class MessageStorage:
|
|||||||
"detailed_plain_text": message.detailed_plain_text,
|
"detailed_plain_text": message.detailed_plain_text,
|
||||||
"topic": topic,
|
"topic": topic,
|
||||||
}
|
}
|
||||||
self.db.db.messages.insert_one(message_data)
|
self.db.messages.insert_one(message_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
|
|
||||||
|
|||||||
@@ -40,20 +40,20 @@ class ImageManager:
|
|||||||
|
|
||||||
def _ensure_image_collection(self):
|
def _ensure_image_collection(self):
|
||||||
"""确保images集合存在并创建索引"""
|
"""确保images集合存在并创建索引"""
|
||||||
if 'images' not in self.db.db.list_collection_names():
|
if 'images' not in self.db.list_collection_names():
|
||||||
self.db.db.create_collection('images')
|
self.db.create_collection('images')
|
||||||
# 创建索引
|
# 创建索引
|
||||||
self.db.db.images.create_index([('hash', 1)], unique=True)
|
self.db.images.create_index([('hash', 1)], unique=True)
|
||||||
self.db.db.images.create_index([('url', 1)])
|
self.db.images.create_index([('url', 1)])
|
||||||
self.db.db.images.create_index([('path', 1)])
|
self.db.images.create_index([('path', 1)])
|
||||||
|
|
||||||
def _ensure_description_collection(self):
|
def _ensure_description_collection(self):
|
||||||
"""确保image_descriptions集合存在并创建索引"""
|
"""确保image_descriptions集合存在并创建索引"""
|
||||||
if 'image_descriptions' not in self.db.db.list_collection_names():
|
if 'image_descriptions' not in self.db.list_collection_names():
|
||||||
self.db.db.create_collection('image_descriptions')
|
self.db.create_collection('image_descriptions')
|
||||||
# 创建索引
|
# 创建索引
|
||||||
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
self.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
||||||
self.db.db.image_descriptions.create_index([('type', 1)])
|
self.db.image_descriptions.create_index([('type', 1)])
|
||||||
|
|
||||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||||
"""从数据库获取图片描述
|
"""从数据库获取图片描述
|
||||||
@@ -65,7 +65,7 @@ class ImageManager:
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: 描述文本,如果不存在则返回None
|
Optional[str]: 描述文本,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
result= self.db.db.image_descriptions.find_one({
|
result= self.db.image_descriptions.find_one({
|
||||||
'hash': image_hash,
|
'hash': image_hash,
|
||||||
'type': description_type
|
'type': description_type
|
||||||
})
|
})
|
||||||
@@ -79,7 +79,7 @@ class ImageManager:
|
|||||||
description: 描述文本
|
description: 描述文本
|
||||||
description_type: 描述类型 ('emoji' 或 'image')
|
description_type: 描述类型 ('emoji' 或 'image')
|
||||||
"""
|
"""
|
||||||
self.db.db.image_descriptions.update_one(
|
self.db.image_descriptions.update_one(
|
||||||
{'hash': image_hash, 'type': description_type},
|
{'hash': image_hash, 'type': description_type},
|
||||||
{
|
{
|
||||||
'$set': {
|
'$set': {
|
||||||
@@ -121,7 +121,7 @@ class ImageManager:
|
|||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 查重
|
# 查重
|
||||||
existing = self.db.db.images.find_one({'hash': image_hash})
|
existing = self.db.images.find_one({'hash': image_hash})
|
||||||
if existing:
|
if existing:
|
||||||
return existing['path']
|
return existing['path']
|
||||||
|
|
||||||
@@ -142,7 +142,7 @@ class ImageManager:
|
|||||||
'description': description,
|
'description': description,
|
||||||
'timestamp': timestamp
|
'timestamp': timestamp
|
||||||
}
|
}
|
||||||
self.db.db.images.insert_one(image_doc)
|
self.db.images.insert_one(image_doc)
|
||||||
|
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
@@ -159,7 +159,7 @@ class ImageManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 先查找是否已存在
|
# 先查找是否已存在
|
||||||
existing = self.db.db.images.find_one({'url': url})
|
existing = self.db.images.find_one({'url': url})
|
||||||
if existing:
|
if existing:
|
||||||
return existing['path']
|
return existing['path']
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ class ImageManager:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否存在
|
bool: 是否存在
|
||||||
"""
|
"""
|
||||||
return self.db.db.images.find_one({'url': url}) is not None
|
return self.db.images.find_one({'url': url}) is not None
|
||||||
|
|
||||||
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
|
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
|
||||||
"""检查图像是否已存在
|
"""检查图像是否已存在
|
||||||
@@ -226,7 +226,7 @@ class ImageManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
return self.db.db.images.find_one({'hash': image_hash}) is not None
|
return self.db.images.find_one({'hash': image_hash}) is not None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检查哈希失败: {str(e)}")
|
logger.error(f"检查哈希失败: {str(e)}")
|
||||||
@@ -269,7 +269,7 @@ class ImageManager:
|
|||||||
'description': description,
|
'description': description,
|
||||||
'timestamp': timestamp
|
'timestamp': timestamp
|
||||||
}
|
}
|
||||||
self.db.db.images.update_one(
|
self.db.images.update_one(
|
||||||
{'hash': image_hash},
|
{'hash': image_hash},
|
||||||
{'$set': image_doc},
|
{'$set': image_doc},
|
||||||
upsert=True
|
upsert=True
|
||||||
@@ -326,7 +326,7 @@ class ImageManager:
|
|||||||
'description': description,
|
'description': description,
|
||||||
'timestamp': timestamp
|
'timestamp': timestamp
|
||||||
}
|
}
|
||||||
self.db.db.images.update_one(
|
self.db.images.update_one(
|
||||||
{'hash': image_hash},
|
{'hash': image_hash},
|
||||||
{'$set': image_doc},
|
{'$set': image_doc},
|
||||||
upsert=True
|
upsert=True
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class Memory_graph:
|
|||||||
dot_data = {
|
dot_data = {
|
||||||
"concept": node
|
"concept": node
|
||||||
}
|
}
|
||||||
self.db.db.store_memory_dots.insert_one(dot_data)
|
self.db.store_memory_dots.insert_one(dot_data)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dots(self):
|
def dots(self):
|
||||||
@@ -106,7 +106,7 @@ class Memory_graph:
|
|||||||
def get_random_chat_from_db(self, length: int, timestamp: str):
|
def get_random_chat_from_db(self, length: int, timestamp: str):
|
||||||
# 从数据库中根据时间戳获取离其最近的聊天记录
|
# 从数据库中根据时间戳获取离其最近的聊天记录
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
|
||||||
logger.info(
|
logger.info(
|
||||||
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ class Memory_graph:
|
|||||||
group_id = closest_record['group_id'] # 获取groupid
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_record = list(
|
chat_record = list(
|
||||||
self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
|
||||||
length))
|
length))
|
||||||
for record in chat_record:
|
for record in chat_record:
|
||||||
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
|
||||||
@@ -130,34 +130,34 @@ class Memory_graph:
|
|||||||
|
|
||||||
def save_graph_to_db(self):
|
def save_graph_to_db(self):
|
||||||
# 清空现有的图数据
|
# 清空现有的图数据
|
||||||
self.db.db.graph_data.delete_many({})
|
self.db.graph_data.delete_many({})
|
||||||
# 保存节点
|
# 保存节点
|
||||||
for node in self.G.nodes(data=True):
|
for node in self.G.nodes(data=True):
|
||||||
node_data = {
|
node_data = {
|
||||||
'concept': node[0],
|
'concept': node[0],
|
||||||
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
'memory_items': node[1].get('memory_items', []) # 默认为空列表
|
||||||
}
|
}
|
||||||
self.db.db.graph_data.nodes.insert_one(node_data)
|
self.db.graph_data.nodes.insert_one(node_data)
|
||||||
# 保存边
|
# 保存边
|
||||||
for edge in self.G.edges():
|
for edge in self.G.edges():
|
||||||
edge_data = {
|
edge_data = {
|
||||||
'source': edge[0],
|
'source': edge[0],
|
||||||
'target': edge[1]
|
'target': edge[1]
|
||||||
}
|
}
|
||||||
self.db.db.graph_data.edges.insert_one(edge_data)
|
self.db.graph_data.edges.insert_one(edge_data)
|
||||||
|
|
||||||
def load_graph_from_db(self):
|
def load_graph_from_db(self):
|
||||||
# 清空当前图
|
# 清空当前图
|
||||||
self.G.clear()
|
self.G.clear()
|
||||||
# 加载节点
|
# 加载节点
|
||||||
nodes = self.db.db.graph_data.nodes.find()
|
nodes = self.db.graph_data.nodes.find()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
memory_items = node.get('memory_items', [])
|
memory_items = node.get('memory_items', [])
|
||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
self.G.add_node(node['concept'], memory_items=memory_items)
|
self.G.add_node(node['concept'], memory_items=memory_items)
|
||||||
# 加载边
|
# 加载边
|
||||||
edges = self.db.db.graph_data.edges.find()
|
edges = self.db.graph_data.edges.find()
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
self.G.add_edge(edge['source'], edge['target'])
|
self.G.add_edge(edge['source'], edge['target'])
|
||||||
|
|
||||||
|
|||||||
@@ -892,15 +892,6 @@ config = driver.config
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
|
||||||
uri=os.getenv("MONGODB_URI"),
|
|
||||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
|
||||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
|
||||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
|
||||||
username=os.getenv("MONGODB_USERNAME"),
|
|
||||||
password=os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
|
||||||
)
|
|
||||||
# 创建记忆图
|
# 创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
# 创建海马体
|
# 创建海马体
|
||||||
|
|||||||
@@ -41,10 +41,10 @@ class LLM_request:
|
|||||||
"""初始化数据库集合"""
|
"""初始化数据库集合"""
|
||||||
try:
|
try:
|
||||||
# 创建llm_usage集合的索引
|
# 创建llm_usage集合的索引
|
||||||
self.db.db.llm_usage.create_index([("timestamp", 1)])
|
self.db.llm_usage.create_index([("timestamp", 1)])
|
||||||
self.db.db.llm_usage.create_index([("model_name", 1)])
|
self.db.llm_usage.create_index([("model_name", 1)])
|
||||||
self.db.db.llm_usage.create_index([("user_id", 1)])
|
self.db.llm_usage.create_index([("user_id", 1)])
|
||||||
self.db.db.llm_usage.create_index([("request_type", 1)])
|
self.db.llm_usage.create_index([("request_type", 1)])
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("创建数据库索引失败")
|
logger.error("创建数据库索引失败")
|
||||||
|
|
||||||
@@ -73,7 +73,7 @@ class LLM_request:
|
|||||||
"status": "success",
|
"status": "success",
|
||||||
"timestamp": datetime.now()
|
"timestamp": datetime.now()
|
||||||
}
|
}
|
||||||
self.db.db.llm_usage.insert_one(usage_data)
|
self.db.llm_usage.insert_one(usage_data)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Token使用情况 - 模型: {self.model_name}, "
|
f"Token使用情况 - 模型: {self.model_name}, "
|
||||||
f"用户: {user_id}, 类型: {request_type}, "
|
f"用户: {user_id}, 类型: {request_type}, "
|
||||||
|
|||||||
@@ -14,16 +14,6 @@ from ..models.utils_model import LLM_request
|
|||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
|
||||||
uri=os.getenv("MONGODB_URI"),
|
|
||||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
|
||||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
|
||||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
|
||||||
username=os.getenv("MONGODB_USERNAME"),
|
|
||||||
password=os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
|
||||||
)
|
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 根据global_config.llm_normal这一字典配置指定模型
|
# 根据global_config.llm_normal这一字典配置指定模型
|
||||||
@@ -56,7 +46,7 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
schedule_text = str
|
schedule_text = str
|
||||||
|
|
||||||
existing_schedule = self.db.db.schedule.find_one({"date": date_str})
|
existing_schedule = self.db.schedule.find_one({"date": date_str})
|
||||||
if existing_schedule:
|
if existing_schedule:
|
||||||
logger.debug(f"{date_str}的日程已存在:")
|
logger.debug(f"{date_str}的日程已存在:")
|
||||||
schedule_text = existing_schedule["schedule"]
|
schedule_text = existing_schedule["schedule"]
|
||||||
@@ -73,7 +63,7 @@ class ScheduleGenerator:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
|
||||||
self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成日程失败: {str(e)}")
|
logger.error(f"生成日程失败: {str(e)}")
|
||||||
schedule_text = "生成日程时出错了"
|
schedule_text = "生成日程时出错了"
|
||||||
@@ -153,7 +143,7 @@ class ScheduleGenerator:
|
|||||||
"""打印完整的日程安排"""
|
"""打印完整的日程安排"""
|
||||||
if not self._parse_schedule(self.today_schedule_text):
|
if not self._parse_schedule(self.today_schedule_text):
|
||||||
logger.warning("今日日程有误,将在下次运行时重新生成")
|
logger.warning("今日日程有误,将在下次运行时重新生成")
|
||||||
self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
|
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
|
||||||
else:
|
else:
|
||||||
logger.info("=== 今日日程安排 ===")
|
logger.info("=== 今日日程安排 ===")
|
||||||
for time_str, activity in self.today_schedule.items():
|
for time_str, activity in self.today_schedule.items():
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ class LLMStatistics:
|
|||||||
"costs_by_model": defaultdict(float)
|
"costs_by_model": defaultdict(float)
|
||||||
}
|
}
|
||||||
|
|
||||||
cursor = self.db.db.llm_usage.find({
|
cursor = self.db.llm_usage.find({
|
||||||
"timestamp": {"$gte": start_time}
|
"timestamp": {"$gte": start_time}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user