Merge remote-tracking branch 'upstream/debug' into debug

This commit is contained in:
tcmofashi
2025-03-12 21:47:53 +08:00
20 changed files with 189 additions and 359 deletions

View File

@@ -7,7 +7,6 @@ from nonebot import get_driver, on_message, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent
from nonebot.typing import T_State
from ...common.database import Database
from ..moods.moods import MoodManager # 导入情绪管理器
from ..schedule.schedule_generator import bot_schedule
from ..utils.statistic import LLMStatistics

View File

@@ -6,7 +6,7 @@ from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import GroupInfo, UserInfo
@@ -83,7 +83,6 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
@@ -111,11 +110,11 @@ class ChatManager:
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.list_collection_names():
self.db.create_collection("chat_streams")
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
# 创建索引
self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.chat_streams.create_index(
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
@@ -168,7 +167,7 @@ class ChatManager:
return stream
# 检查数据库中是否存在
data = self.db.chat_streams.find_one({"stream_id": stream_id})
data = db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -204,7 +203,7 @@ class ChatManager:
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.chat_streams.update_one(
db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
@@ -216,7 +215,7 @@ class ChatManager:
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.chat_streams.find({})
all_streams = db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream

View File

@@ -12,7 +12,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import ImageManager, image_path_to_base64
@@ -30,12 +30,10 @@ class EmojiManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(
@@ -50,7 +48,6 @@ class EmojiManager:
"""初始化数据库连接和表情目录"""
if not self._initialized:
try:
self.db = Database.get_instance()
self._ensure_emoji_collection()
self._ensure_emoji_dir()
self._initialized = True
@@ -78,16 +75,16 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
"""
if "emoji" not in self.db.list_collection_names():
self.db.create_collection("emoji")
self.db.emoji.create_index([("embedding", "2dsphere")])
self.db.emoji.create_index([("filename", 1)], unique=True)
if "emoji" not in db.list_collection_names():
db.create_collection("emoji")
db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_id: str):
"""记录表情使用次数"""
try:
self._ensure_db()
self.db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}")
@@ -118,7 +115,7 @@ class EmojiManager:
try:
# 获取所有表情包
all_emojis = list(self.db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1}))
all_emojis = list(db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1}))
if not all_emojis:
logger.warning("数据库中没有任何表情包")
@@ -155,7 +152,7 @@ class EmojiManager:
if selected_emoji and "path" in selected_emoji:
# 更新使用次数
self.db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})"
@@ -235,14 +232,14 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过
existing_emoji = self.db["emoji"].find_one({"filename": filename})
existing_emoji = db["emoji"].find_one({"filename": filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get("discription")
# 检查是否在images集合中存在
existing_image = image_manager.db.images.find_one({"hash": image_hash})
existing_image = db.images.find_one({"hash": image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
@@ -252,7 +249,7 @@ class EmojiManager:
"description": description,
"timestamp": int(time.time()),
}
image_manager.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"同步已存在的表情包到images集合: {filename}")
@@ -295,7 +292,7 @@ class EmojiManager:
}
# 保存到emoji数据库
self.db["emoji"].insert_one(emoji_record)
db["emoji"].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}")
@@ -307,7 +304,7 @@ class EmojiManager:
"description": description,
"timestamp": int(time.time()),
}
image_manager.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"同步保存到images集合: {filename}")
@@ -331,7 +328,7 @@ class EmojiManager:
try:
self._ensure_db()
# 获取所有表情包记录
all_emojis = list(self.db.emoji.find())
all_emojis = list(db.emoji.find())
removed_count = 0
total_count = len(all_emojis)
@@ -339,26 +336,26 @@ class EmojiManager:
try:
if "path" not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({"_id": emoji["_id"]})
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
if "embedding" not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.emoji.delete_one({"_id": emoji["_id"]})
db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1
continue
if "hash" not in emoji:
logger.warning(f"发现缺失记录缺少hash字段ID: {emoji.get('_id', 'unknown')}")
hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest()
self.db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
# 检查文件是否存在
if not os.path.exists(emoji["path"]):
logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录
result = self.db.emoji.delete_one({"_id": emoji["_id"]})
result = db.emoji.delete_one({"_id": emoji["_id"]})
if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1
@@ -369,7 +366,7 @@ class EmojiManager:
continue
# 验证清理结果
remaining_count = self.db.emoji.count_documents({})
remaining_count = db.emoji.count_documents({})
if removed_count > 0:
logger.success(f"已清理 {removed_count} 个失效的表情包记录")
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Union
from nonebot import get_driver
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..models.utils_model import LLM_request
from .config import global_config
from .message import MessageRecv, MessageThinking, Message
@@ -34,7 +34,6 @@ class ResponseGenerator:
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance()
self.current_model_type = "r1" # 默认使用 R1
async def generate_response(
@@ -154,7 +153,7 @@ class ResponseGenerator:
reasoning_content: str,
):
"""保存对话记录到数据库"""
self.db.reasoning_logs.insert_one(
db.reasoning_logs.insert_one(
{
"time": time.time(),
"chat_id": message.chat_stream.stream_id,
@@ -211,7 +210,6 @@ class ResponseGenerator:
class InitiativeMessageGenerate:
def __init__(self):
self.db = Database.get_instance()
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7)
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7)
self.model_r1_distill = LLM_request(

View File

@@ -3,7 +3,7 @@ import time
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from ..memory_system.memory import hippocampus, memory_graph
from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
@@ -16,7 +16,6 @@ class PromptBuilder:
def __init__(self):
self.prompt_built = ''
self.activate_messages = ''
self.db = Database.get_instance()
@@ -76,7 +75,7 @@ class PromptBuilder:
chat_in_group=True
chat_talking_prompt = ''
if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_talking_prompt = get_recent_group_detailed_plain_text(stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
chat_stream=chat_manager.get_stream(stream_id)
if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
@@ -199,7 +198,7 @@ class PromptBuilder:
chat_talking_prompt = ''
if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
chat_talking_prompt = get_recent_group_detailed_plain_text(group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
@@ -311,7 +310,7 @@ class PromptBuilder:
{"$project": {"content": 1, "similarity": 1}}
]
results = list(self.db.knowledges.aggregate(pipeline))
results = list(db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results:

View File

@@ -2,7 +2,7 @@ import asyncio
from typing import Optional
from loguru import logger
from ...common.database import Database
from ...common.database import db
from .message_base import UserInfo
from .chat_stream import ChatStream
@@ -167,14 +167,12 @@ class RelationshipManager:
async def load_all_relationships(self):
"""加载所有关系对象"""
db = Database.get_instance()
all_relationships = db.relationships.find({})
for data in all_relationships:
await self.load_relationship(data)
async def _start_relationship_manager(self):
"""每5分钟自动保存一次关系数据"""
db = Database.get_instance()
# 获取所有关系记录
all_relationships = db.relationships.find({})
# 依次加载每条记录
@@ -205,7 +203,6 @@ class RelationshipManager:
age = relationship.age
saved = relationship.saved
db = Database.get_instance()
db.relationships.update_one(
{'user_id': user_id, 'platform': platform},
{'$set': {

View File

@@ -1,15 +1,12 @@
from typing import Optional, Union
from ...common.database import Database
from ...common.database import db
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger
class MessageStorage:
def __init__(self):
self.db = Database.get_instance()
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库"""
try:
@@ -23,7 +20,7 @@ class MessageStorage:
"detailed_plain_text": message.detailed_plain_text,
"topic": topic,
}
self.db.messages.insert_one(message_data)
db.messages.insert_one(message_data)
except Exception:
logger.exception("存储消息失败")

View File

@@ -16,6 +16,7 @@ from .message import MessageRecv,Message
from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager
from ...common.database import db
driver = get_driver()
config = driver.config
@@ -76,11 +77,10 @@ def calculate_information_content(text):
return entropy
def get_cloest_chat_from_db(db, length: int, timestamp: str):
def get_closest_chat_from_db(length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
@@ -115,11 +115,10 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
return []
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
async def get_recent_group_messages(chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录
Args:
db: Database实例
group_id: 群组ID
limit: 获取消息数量默认12条
@@ -161,7 +160,7 @@ async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
return message_objects
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
def get_recent_group_detailed_plain_text(chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.messages.find(
{"chat_id": chat_stream_id},
{

View File

@@ -10,7 +10,7 @@ import io
from loguru import logger
from nonebot import get_driver
from ...common.database import Database
from ...common.database import db
from ..chat.config import global_config
from ..models.utils_model import LLM_request
@@ -25,13 +25,11 @@ class ImageManager:
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
@@ -44,20 +42,20 @@ class ImageManager:
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if "images" not in self.db.list_collection_names():
self.db.create_collection("images")
if "images" not in db.list_collection_names():
db.create_collection("images")
# 创建索引
self.db.images.create_index([("hash", 1)], unique=True)
self.db.images.create_index([("url", 1)])
self.db.images.create_index([("path", 1)])
db.images.create_index([("hash", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in self.db.list_collection_names():
self.db.create_collection("image_descriptions")
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
# 创建索引
self.db.image_descriptions.create_index([("hash", 1)], unique=True)
self.db.image_descriptions.create_index([("type", 1)])
db.image_descriptions.create_index([("hash", 1)], unique=True)
db.image_descriptions.create_index([("type", 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -69,9 +67,7 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
if image_hash is None:
return
result = self.db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
@@ -82,9 +78,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
if image_hash is None:
return
self.db.image_descriptions.update_one(
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{"$set": {"description": description, "timestamp": int(time.time())}},
upsert=True,
@@ -120,7 +114,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重
existing = self.db.images.find_one({"hash": image_hash})
existing = db.images.find_one({"hash": image_hash})
if existing:
return existing["path"]
@@ -141,7 +135,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.insert_one(image_doc)
db.images.insert_one(image_doc)
return file_path
@@ -158,7 +152,7 @@ class ImageManager:
"""
try:
# 先查找是否已存在
existing = self.db.images.find_one({"url": url})
existing = db.images.find_one({"url": url})
if existing:
return existing["path"]
@@ -201,7 +195,7 @@ class ImageManager:
Returns:
bool: 是否存在
"""
return self.db.images.find_one({"url": url}) is not None
return db.images.find_one({"url": url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
@@ -224,7 +218,7 @@ class ImageManager:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.images.find_one({"hash": image_hash}) is not None
return db.images.find_one({"hash": image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
@@ -268,7 +262,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
@@ -328,7 +322,7 @@ class ImageManager:
"description": description,
"timestamp": timestamp,
}
self.db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")