Optimize the existing knowledge base system

Voyager1
2025-04-05 17:31:34 +08:00
parent 0ed022f874
commit db14d9c39b
2 changed files with 173 additions and 73 deletions
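
In short: the first file reworks PromptBuilder.get_prompt_info from a single whole-message embedding lookup into multi-topic retrieval, and the second file simplifies KnowledgeLibrary.split_content to blank-line splitting. A condensed sketch of the new retrieval flow, paraphrasing the first diff below (extract_topics and render_grouped are hypothetical stand-ins for logic the real method inlines; error handling and logging omitted):

async def get_prompt_info_sketch(builder, message: str, threshold: float) -> str:
    # 1. Ask the LLM for up to 5 topics; fall back to jieba keywords (see diff).
    topics = await extract_topics(message)  # hypothetical helper
    if not topics:
        # No topics: single whole-message query, now with limit=3.
        emb = await get_embedding(message, request_type="prompt_build")
        return builder.get_info_from_db(emb, limit=3, threshold=threshold)
    # 2. Embed the original message plus every topic (one API call each).
    embeddings = {t: await get_embedding(t, request_type="prompt_build")
                  for t in [message, *topics]}
    # 3. Query the knowledge base per embedding, keeping raw result dicts.
    results = []
    for topic, emb in embeddings.items():
        for r in builder.get_info_from_db(emb, limit=3, threshold=threshold, return_raw=True):
            r["topic"] = topic
            results.append(r)
    # 4-6. Deduplicate by content, sort by similarity, keep the top 10.
    seen, unique = set(), []
    for r in sorted(results, key=lambda r: r["similarity"], reverse=True):
        if r["content"] not in seen:
            seen.add(r["content"])
            unique.append(r)
    # 7. Group by topic and render as prompt text.
    return render_grouped(unique[:10])  # hypothetical helper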

View File

@@ -1,16 +1,19 @@
 import random
 import time
-from typing import Optional
+from typing import Optional, Union
+import re
+import jieba
+import numpy as np
 from ....common.database import db
-from ...memory_system.Hippocampus import HippocampusManager
-from ...moods.moods import MoodManager
-from ...schedule.schedule_generator import bot_schedule
-from ...config.config import global_config
 from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
 from ...chat.chat_stream import chat_manager
-from src.common.logger import get_module_logger
+from ...moods.moods import MoodManager
+from ...memory_system.Hippocampus import HippocampusManager
+from ...schedule.schedule_generator import bot_schedule
+from ...config.config import global_config
 from ...person_info.relationship_manager import relationship_manager
+from src.common.logger import get_module_logger

 logger = get_module_logger("prompt")
@@ -128,7 +131,7 @@ class PromptBuilder:
         # Knowledge construction
         start_time = time.time()
         prompt_info = ""
-        prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
+        prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
         if prompt_info:
             prompt_info = f"""\n你有以下这些**知识**：\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
@@ -158,16 +161,156 @@ class PromptBuilder:
         return prompt

     async def get_prompt_info(self, message: str, threshold: float):
+        start_time = time.time()
         related_info = ""
         logger.debug(f"Fetching knowledge base content, source message: {message[:30]}..., message length: {len(message)}")
-        embedding = await get_embedding(message, request_type="prompt_build")
-        related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
+
+        # 1. First extract topics with the LLM, similar to the memory system's approach
+        topics = []
+        try:
+            # Try the memory system's topic extraction first
+            hippocampus = HippocampusManager.get_instance()._hippocampus
+            topic_num = min(5, max(1, int(len(message) * 0.1)))
+            topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
+
+            # Extract the keywords enclosed in angle brackets
+            topics = re.findall(r"<([^>]+)>", topics_response[0])
+            if not topics:
+                topics = []
+            else:
+                topics = [
+                    topic.strip()
+                    for topic in ",".join(topics).replace("，", ",").replace("、", ",").replace(" ", ",").split(",")
+                    if topic.strip()
+                ]
+            logger.info(f"Topics extracted by the LLM: {', '.join(topics)}")
+        except Exception as e:
+            logger.error(f"LLM topic extraction failed: {str(e)}")
+            # If LLM extraction fails, fall back to jieba segmentation for keywords
+            words = jieba.cut(message)
+            topics = [word for word in words if len(word) > 1][:5]
+            logger.info(f"Topics extracted with jieba: {', '.join(topics)}")
+
+        # If no topics could be extracted, query with the whole message directly
+        if not topics:
+            logger.info("No topics extracted, querying with the whole message")
+            embedding = await get_embedding(message, request_type="prompt_build")
+            if not embedding:
+                logger.error("Failed to get the message embedding")
+                return ""
+            related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
+            logger.info(f"Knowledge base retrieval finished, total time: {time.time() - start_time:.3f}s")
+            return related_info
+
+        # 2. Query the knowledge base for each topic
+        logger.info(f"Processing knowledge base queries for {len(topics)} topics")
+
+        # Optimization: batch the embedding requests to reduce API calls
+        embeddings = {}
+        topics_batch = [topic for topic in topics if len(topic) > 0]
+        if message:  # make sure the message is non-empty
+            topics_batch.append(message)
+
+        # Fetch the embeddings in one batch
+        embed_start_time = time.time()
+        for text in topics_batch:
+            if not text or len(text.strip()) == 0:
+                continue
+            try:
+                embedding = await get_embedding(text, request_type="prompt_build")
+                if embedding:
+                    embeddings[text] = embedding
+                else:
+                    logger.warning(f"Failed to get the embedding for '{text}'")
+            except Exception as e:
+                logger.error(f"Error while getting the embedding for '{text}': {str(e)}")
+        logger.info(f"Batch embedding fetch finished, time: {time.time() - embed_start_time:.3f}s")
+
+        if not embeddings:
+            logger.error("All embedding requests failed")
+            return ""
+
+        # 3. Query the knowledge base for each embedding
+        all_results = []
+        query_start_time = time.time()
+
+        # Add the results for the original message first
+        if message in embeddings:
+            original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
+            if original_results:
+                for result in original_results:
+                    result["topic"] = "原始消息"
+                all_results.extend(original_results)
+                logger.info(f"The original message query returned {len(original_results)} results")
+
+        # Then add the results for each topic
+        for topic in topics:
+            if not topic or topic not in embeddings:
+                continue
+            try:
+                topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
+                if topic_results:
+                    # Tag each result with its topic
+                    for result in topic_results:
+                        result["topic"] = topic
+                    all_results.extend(topic_results)
+                    logger.info(f"Topic '{topic}' returned {len(topic_results)} results")
+            except Exception as e:
+                logger.error(f"Error while querying topic '{topic}': {str(e)}")
+        logger.info(f"Knowledge base queries finished, time: {time.time() - query_start_time:.3f}s, {len(all_results)} results in total")
+
+        # 4. Deduplicate and filter
+        process_start_time = time.time()
+        unique_contents = set()
+        filtered_results = []
+        for result in all_results:
+            content = result["content"]
+            if content not in unique_contents:
+                unique_contents.add(content)
+                filtered_results.append(result)
+
+        # 5. Sort by similarity
+        filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
+
+        # 6. Cap the total number of results at 10
+        filtered_results = filtered_results[:10]
+        logger.info(f"Result processing finished, time: {time.time() - process_start_time:.3f}s, {len(filtered_results)} results left after filtering")
+
+        # 7. Format the output
+        if filtered_results:
+            format_start_time = time.time()
+            grouped_results = {}
+            for result in filtered_results:
+                topic = result["topic"]
+                if topic not in grouped_results:
+                    grouped_results[topic] = []
+                grouped_results[topic].append(result)
+
+            # Organize the output by topic
+            for topic, results in grouped_results.items():
+                related_info += f"【主题: {topic}】\n"
+                for i, result in enumerate(results, 1):
+                    similarity = result["similarity"]
+                    content = result["content"].strip()
+                    # Debug: prefix each entry with its index and similarity
+                    # related_info += f"{i}. [{similarity:.2f}] {content}\n"
+                    related_info += f"{content}\n"
+                related_info += "\n"
+            logger.info(f"Output formatting finished, time: {time.time() - format_start_time:.3f}s")
+
+        logger.info(f"Knowledge base retrieval total time: {time.time() - start_time:.3f}s")
         return related_info

-    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
+    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
         if not query_embedding:
-            return ""
+            return "" if not return_raw else []

         # Use cosine similarity for the lookup
         pipeline = [
             {
@@ -221,13 +364,16 @@ class PromptBuilder:
         ]
         results = list(db.knowledges.aggregate(pipeline))
-        # print(f"\033[1;34m[debug]\033[0m knowledge base query results: {results}")
+        logger.debug(f"Number of knowledge base query results: {len(results)}")
         if not results:
-            return ""
+            return "" if not return_raw else []

-        # Return everything that was found, separated by newlines
-        return "\n".join(str(result["content"]) for result in results)
+        if return_raw:
+            return results
+        else:
+            # Return everything that was found, separated by newlines
+            return "\n".join(str(result["content"]) for result in results)


 prompt_builder = PromptBuilder()
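
The aggregation pipeline that backs get_info_from_db is collapsed in the hunk above. For reference, a minimal Python-side equivalent of the cosine-similarity lookup, assuming knowledge documents carry content and embedding fields (an illustrative numpy sketch, not the project's actual MongoDB pipeline):

import numpy as np

def cosine_top_k(query_embedding: list, docs: list, limit: int, threshold: float) -> list:
    # Score every document by cosine similarity to the query embedding,
    # drop those below the threshold, and return the top `limit` matches.
    q = np.asarray(query_embedding, dtype=float)
    scored = []
    for doc in docs:  # each doc assumed to look like {"content": str, "embedding": list[float]}
        v = np.asarray(doc["embedding"], dtype=float)
        sim = float(q @ v / (np.linalg.norm(q) * np.linalg.norm(v)))
        if sim >= threshold:
            scored.append({"content": doc["content"], "similarity": sim})
    scored.sort(key=lambda d: d["similarity"], reverse=True)
    return scored[:limit]

This also shows why return_raw is useful: the raw dicts keep the similarity score, which the new get_prompt_info needs for its global sort across topics.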

View File

@@ -41,7 +41,7 @@ class KnowledgeLibrary:
             return f.read()

     def split_content(self, content: str, max_length: int = 512) -> list:
-        """Split the content into appropriately sized chunks, keeping paragraphs intact
+        """Split the content into appropriately sized chunks, splitting on blank lines

         Args:
             content: the text content to split
@@ -50,67 +50,21 @@ class KnowledgeLibrary:
         Returns:
             list: the list of text chunks after splitting
         """
-        # First split by paragraph
+        # Split the content on blank lines
         paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
         chunks = []
-        current_chunk = []
-        current_length = 0

         for para in paragraphs:
             para_length = len(para)
-            # If a single paragraph already exceeds the maximum length
-            if para_length > max_length:
-                # If the current chunk is non-empty, save it first
-                if current_chunk:
-                    chunks.append("\n".join(current_chunk))
-                    current_chunk = []
-                    current_length = 0
-                # Split the long paragraph into sentences
-                sentences = [
-                    s.strip()
-                    for s in para.replace("。", "\n").replace("！", "\n").replace("？", "\n").split("\n")
-                    if s.strip()
-                ]
-                temp_chunk = []
-                temp_length = 0
-                for sentence in sentences:
-                    sentence_length = len(sentence)
-                    if sentence_length > max_length:
-                        # If a single sentence is too long, force a split by length
-                        if temp_chunk:
-                            chunks.append("\n".join(temp_chunk))
-                            temp_chunk = []
-                            temp_length = 0
-                        for i in range(0, len(sentence), max_length):
-                            chunks.append(sentence[i : i + max_length])
-                    elif temp_length + sentence_length + 1 <= max_length:
-                        temp_chunk.append(sentence)
-                        temp_length += sentence_length + 1
-                    else:
-                        chunks.append("\n".join(temp_chunk))
-                        temp_chunk = [sentence]
-                        temp_length = sentence_length
-                if temp_chunk:
-                    chunks.append("\n".join(temp_chunk))
-            # If the current paragraph plus the existing chunk stays within the limit
-            elif current_length + para_length + 1 <= max_length:
-                current_chunk.append(para)
-                current_length += para_length + 1
+            # If the paragraph fits within the maximum length, add it directly
+            if para_length <= max_length:
+                chunks.append(para)
             else:
-                # Save the current chunk and start a new one
-                chunks.append("\n".join(current_chunk))
-                current_chunk = [para]
-                current_length = para_length
-
-        # Add the final chunk
-        if current_chunk:
-            chunks.append("\n".join(current_chunk))
+                # If the paragraph exceeds the maximum length, cut it by max_length
+                for i in range(0, para_length, max_length):
+                    chunks.append(para[i:i + max_length])

         return chunks

     def get_embedding(self, text: str) -> list:
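
The simplified split_content drops the old paragraph-packing and sentence-splitting logic: each blank-line-separated paragraph becomes its own chunk, and an oversized paragraph is cut into fixed max_length slices. A quick illustration of the expected behavior (assuming KnowledgeLibrary constructs with no arguments, which this diff does not show):

lib = KnowledgeLibrary()  # assumed no-arg construction

text = "short paragraph\n\n" + "x" * 1200
chunks = lib.split_content(text, max_length=512)

# One chunk for the short paragraph, then 512 + 512 + 176 characters for the
# long paragraph, which is sliced at fixed max_length boundaries.
assert chunks[0] == "short paragraph"
assert [len(c) for c in chunks[1:]] == [512, 512, 176]

One consequence worth noting: short paragraphs are no longer packed together into shared chunks, so the knowledge base stores more and smaller entries, and a slice of a long paragraph can now start mid-sentence.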