update: 飞掉老知识库

This commit is contained in:
墨梓柒
2025-05-20 23:03:28 +08:00
parent 25d9032e62
commit afaf24d28c
4 changed files with 6 additions and 528 deletions

View File

@@ -1,5 +1,5 @@
""" """
MaiMBot插件系统 MaiBot模块系统
包含聊天、情绪、记忆、日程等功能模块 包含聊天、情绪、记忆、日程等功能模块
""" """

View File

@@ -4,17 +4,13 @@ from src.individuality.individuality import individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.person_info.relationship_manager import relationship_manager from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time import time
from typing import Union, Optional from typing import Optional
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.knowledge.knowledge_lib import qa_manager
import random import random
import json
import math
from src.common.database.database_model import Knowledges
logger = get_logger("prompt") logger = get_logger("prompt")
@@ -262,130 +258,6 @@ class PromptBuilder:
# --- End choosing template --- # --- End choosing template ---
return prompt return prompt
async def get_prompt_info_old(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.info("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="prompt_build")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="prompt_build")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info
async def get_prompt_info(self, message: str, threshold: float): async def get_prompt_info(self, message: str, threshold: float):
related_info = "" related_info = ""
start_time = time.time() start_time = time.time()
@@ -405,93 +277,11 @@ class PromptBuilder:
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info return related_info
else: else:
logger.debug("从LPMM知识库获取知识失败使用旧版数据库进行检索") logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识,返回空知识...")
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) return "未检索到知识"
related_info += knowledge_from_old
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return related_info
except Exception as e: except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}") logger.error(f"获取知识库内容时发生异常: {str(e)}")
try: return "未检索到知识"
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
related_info += knowledge_from_old
logger.debug(
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
)
return related_info
except Exception as e2:
logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}")
return ""
@staticmethod
def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
results_with_similarity = []
try:
# Fetch all knowledge entries
# This might be inefficient for very large databases.
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
all_knowledges = Knowledges.select()
if not all_knowledges:
return [] if return_raw else ""
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
if query_embedding_magnitude == 0: # Avoid division by zero
return "" if not return_raw else []
for knowledge_item in all_knowledges:
try:
db_embedding_str = knowledge_item.embedding
db_embedding = json.loads(db_embedding_str)
if len(db_embedding) != len(query_embedding):
logger.warning(
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
)
continue
# Calculate Cosine Similarity
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
if db_embedding_magnitude == 0: # Avoid division by zero
similarity = 0.0
else:
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
if similarity >= threshold:
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
except json.JSONDecodeError:
logger.error(
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
)
except Exception as e:
logger.error(f"Error processing knowledge item: {e}")
# Sort by similarity in descending order
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
# Limit results
limited_results = results_with_similarity[:limit]
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
if not limited_results:
return "" if not return_raw else []
if return_raw:
return limited_results
else:
return "\n".join(str(result["content"]) for result in limited_results)
except Exception as e:
logger.error(f"Error querying Knowledges with Peewee: {e}")
return "" if not return_raw else []
init_prompt() init_prompt()

View File

@@ -121,5 +121,5 @@ class QAManager:
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n" found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
return found_knowledge return found_knowledge
else: else:
logger.info("LPMM知识库并未初始化使用旧版数据库进行检索") logger.info("LPMM知识库并未初始化可能是从未导入过知识...")
return None return None

View File

@@ -1,312 +0,0 @@
import os
import sys
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich.traceback import install
install(extra_lines=3)
# 添加项目根目录到 Python 路径
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)
# 现在可以导入src模块
from common.database.database import db # noqa E402
# 加载根目录下的env.edv文件
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):
raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)
class KnowledgeLibrary:
def __init__(self):
self.raw_info_dir = "data/raw_info"
self._ensure_dirs()
self.api_key = os.getenv("SILICONFLOW_KEY")
if not self.api_key:
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
self.console = Console()
def _ensure_dirs(self):
"""确保必要的目录存在"""
os.makedirs(self.raw_info_dir, exist_ok=True)
@staticmethod
def read_file(file_path: str) -> str:
"""读取文件内容"""
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
@staticmethod
def split_content(content: str, max_length: int = 512) -> list:
"""将内容分割成适当大小的块,按空行分割
Args:
content: 要分割的文本内容
max_length: 每个块的最大长度
Returns:
list: 分割后的文本块列表
"""
# 按空行分割内容
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
chunks = []
for para in paragraphs:
para_length = len(para)
# 如果段落长度小于等于最大长度,直接添加
if para_length <= max_length:
chunks.append(para)
else:
# 如果段落超过最大长度,则按最大长度切分
for i in range(0, para_length, max_length):
chunks.append(para[i : i + max_length])
return chunks
def get_embedding(self, text: str) -> list:
"""获取文本的embedding向量"""
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {"model": "BAAI/bge-m3", "input": text, "encoding_format": "float"}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(url, json=payload, headers=headers)
if response.status_code != 200:
print(f"获取embedding失败: {response.text}")
return None
return response.json()["data"][0]["embedding"]
def process_files(self, knowledge_length: int = 512):
"""处理raw_info目录下的所有txt文件"""
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith(".txt")]
if not txt_files:
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
return
total_stats = {"processed_files": 0, "total_chunks": 0, "failed_files": [], "skipped_files": []}
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
for filename in tqdm(txt_files, desc="处理文件进度"):
file_path = os.path.join(self.raw_info_dir, filename)
result = self.process_single_file(file_path, knowledge_length)
self._update_stats(total_stats, result, filename)
self._display_processing_results(total_stats)
def process_single_file(self, file_path: str, knowledge_length: int = 512):
"""处理单个文件"""
result = {"status": "success", "chunks_processed": 0, "error": None}
try:
current_hash = self.calculate_file_hash(file_path)
processed_record = db.processed_files.find_one({"file_path": file_path})
if processed_record:
if processed_record.get("hash") == current_hash:
if knowledge_length in processed_record.get("split_by", []):
result["status"] = "skipped"
return result
content = self.read_file(file_path)
chunks = self.split_content(content, knowledge_length)
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
embedding = self.get_embedding(chunk)
if embedding:
knowledge = {
"content": chunk,
"embedding": embedding,
"source_file": file_path,
"split_length": knowledge_length,
"created_at": datetime.now(),
}
db.knowledges.insert_one(knowledge)
result["chunks_processed"] += 1
split_by = processed_record.get("split_by", []) if processed_record else []
if knowledge_length not in split_by:
split_by.append(knowledge_length)
db.knowledges.processed_files.update_one(
{"file_path": file_path},
{"$set": {"hash": current_hash, "last_processed": datetime.now(), "split_by": split_by}},
upsert=True,
)
except Exception as e:
result["status"] = "failed"
result["error"] = str(e)
return result
@staticmethod
def _update_stats(total_stats, result, filename):
"""更新总体统计信息"""
if result["status"] == "success":
total_stats["processed_files"] += 1
total_stats["total_chunks"] += result["chunks_processed"]
elif result["status"] == "failed":
total_stats["failed_files"].append((filename, result["error"]))
elif result["status"] == "skipped":
total_stats["skipped_files"].append(filename)
def _display_processing_results(self, stats):
"""显示处理结果统计"""
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
table = Table(show_header=True, header_style="bold magenta")
table.add_column("统计项", style="dim")
table.add_column("数值")
table.add_row("成功处理文件数", str(stats["processed_files"]))
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
table.add_row("失败的文件数", str(len(stats["failed_files"])))
self.console.print(table)
if stats["failed_files"]:
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
for filename, error in stats["failed_files"]:
self.console.print(f"[red]- {filename}: {error}[/red]")
if stats["skipped_files"]:
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
for filename in stats["skipped_files"]:
self.console.print(f"[yellow]- {filename}[/yellow]")
@staticmethod
def calculate_file_hash(file_path):
"""计算文件的MD5哈希值"""
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def search_similar_segments(self, query: str, limit: int = 5) -> list:
"""搜索与查询文本相似的片段"""
query_embedding = self.get_embedding(query)
if not query_embedding:
return []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1, "file_path": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
return results
# 创建单例实例
knowledge_library = KnowledgeLibrary()
if __name__ == "__main__":
console = Console()
console.print("[bold green]知识库处理工具[/bold green]")
while True:
console.print("\n请选择要执行的操作:")
console.print("[1] 麦麦开始学习")
console.print("[2] 麦麦全部忘光光(仅知识)")
console.print("[q] 退出程序")
choice = input("\n请输入选项: ").strip()
if choice.lower() == "q":
console.print("[yellow]程序退出[/yellow]")
sys.exit(0)
elif choice == "2":
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
if confirm == "y":
db.knowledges.delete_many({})
console.print("[green]已清空所有知识![/green]")
continue
elif choice == "1":
if not os.path.exists(knowledge_library.raw_info_dir):
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
# 询问分割长度
while True:
try:
length_input = input("请输入知识分割长度默认512输入q退出回车使用默认值: ").strip()
if length_input.lower() == "q":
break
if not length_input: # 如果直接回车,使用默认值
knowledge_length = 512
break
knowledge_length = int(length_input)
if knowledge_length <= 0:
print("分割长度必须大于0请重新输入")
continue
break
except ValueError:
print("请输入有效的数字")
continue
if length_input.lower() == "q":
continue
# 测试知识库功能
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
knowledge_library.process_files(knowledge_length=knowledge_length)
else:
console.print("[red]无效的选项,请重新选择[/red]")
continue