import hashlib
import os
import sys
import time

import requests
from nonebot import get_driver

driver = get_driver()
config = driver.config

# Add the project root to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

from src.common.database import Database
from src.plugins.chat.config import llm_config

# Configure the database connection directly
Database.initialize(
    host=config.mongodb_host,
    port=int(config.mongodb_port),
    db_name=config.database_name,
    username=config.mongodb_username,
    password=config.mongodb_password,
    auth_source=config.mongodb_auth_source
)


class KnowledgeLibrary:
    def __init__(self):
        self.db = Database.get_instance()
        self.raw_info_dir = "data/raw_info"
        self._ensure_dirs()

    def _ensure_dirs(self):
        """Ensure the required directories exist."""
        os.makedirs(self.raw_info_dir, exist_ok=True)

    def get_embedding(self, text: str) -> list:
        """Get the embedding vector for a piece of text."""
        url = "https://api.siliconflow.cn/v1/embeddings"
        payload = {
            "model": "BAAI/bge-m3",
            "input": text,
            "encoding_format": "float"
        }
        headers = {
            "Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            print(f"Failed to get embedding: {response.text}")
            return None
        return response.json()['data'][0]['embedding']

    def process_files(self):
        """Process every .txt file under the raw_info directory."""
        for filename in os.listdir(self.raw_info_dir):
            if filename.endswith('.txt'):
                file_path = os.path.join(self.raw_info_dir, filename)
                self.process_single_file(file_path)

    def process_single_file(self, file_path: str):
        """Process a single file."""
        try:
            # Skip files that have already been processed
            if self.db.db.processed_files.find_one({"file_path": file_path}):
                print(f"File already processed, skipping: {file_path}")
                return

            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()

            # Split the content into 300-character segments
            segments = [content[i:i + 300] for i in range(0, len(content), 300)]

            # Process each segment
            for segment in segments:
                if not segment.strip():  # skip empty segments
                    continue

                # Get the embedding
                embedding = self.get_embedding(segment)
                if not embedding:
                    continue

                # Document to store in the database
                doc = {
                    "content": segment,
                    "embedding": embedding,
                    "file_path": file_path,
                    "segment_length": len(segment)
                }

                # Use a stable hash of the segment content as its unique identifier.
                # The built-in hash() is randomized per process, so it cannot
                # deduplicate segments across runs.
                content_hash = hashlib.md5(segment.encode('utf-8')).hexdigest()

                # Update or insert the document
                self.db.db.knowledges.update_one(
                    {"content_hash": content_hash},
                    {"$set": doc},
                    upsert=True
                )

            # Mark the file as processed
            self.db.db.processed_files.insert_one({
                "file_path": file_path,
                "processed_time": time.time()
            })

            print(f"Successfully processed file: {file_path}")

        except Exception as e:
            print(f"Error while processing file {file_path}: {str(e)}")

    def search_similar_segments(self, query: str, limit: int = 5) -> list:
        """Search for stored segments similar to the query text."""
        query_embedding = self.get_embedding(query)
        if not query_embedding:
            return []

        # Compute cosine similarity against the query embedding inside MongoDB
        pipeline = [
            {
                "$addFields": {
                    "dotProduct": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {"$multiply": [
                                        {"$arrayElemAt": ["$embedding", "$$this"]},
                                        {"$arrayElemAt": [query_embedding, "$$this"]}
                                    ]}
                                ]
                            }
                        }
                    },
                    "magnitude1": {
                        "$sqrt": {
                            "$reduce": {
                                "input": "$embedding",
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    },
                    "magnitude2": {
                        "$sqrt": {
                            "$reduce": {
                                "input": query_embedding,
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    }
                }
            },
            {
                "$addFields": {
                    "similarity": {
                        "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
                    }
                }
            },
            {"$sort": {"similarity": -1}},
            {"$limit": limit},
            {"$project": {"content": 1, "similarity": 1, "file_path": 1}}
        ]

        results = list(self.db.db.knowledges.aggregate(pipeline))
        return results


# Create the module-level singleton instance
knowledge_library = KnowledgeLibrary()

if __name__ == "__main__":
    # Build the knowledge library from the raw files
    print("Processing knowledge library files...")
    knowledge_library.process_files()

    # Test the similarity search
    test_query = "麦麦评价一下僕と花"
    print(f"\nSearching for content similar to '{test_query}':")
    results = knowledge_library.search_similar_segments(test_query)
    for result in results:
        print(f"Similarity: {result['similarity']:.4f}")
        print(f"Content: {result['content'][:100]}...")
        print("-" * 50)