优化了现有的知识库系统
This commit is contained in:
@@ -1,16 +1,19 @@
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
import re
|
||||||
|
import jieba
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from ....common.database import db
|
from ....common.database import db
|
||||||
from ...memory_system.Hippocampus import HippocampusManager
|
|
||||||
from ...moods.moods import MoodManager
|
|
||||||
from ...schedule.schedule_generator import bot_schedule
|
|
||||||
from ...config.config import global_config
|
|
||||||
from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
|
from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
|
||||||
from ...chat.chat_stream import chat_manager
|
from ...chat.chat_stream import chat_manager
|
||||||
from src.common.logger import get_module_logger
|
from ...moods.moods import MoodManager
|
||||||
|
from ...memory_system.Hippocampus import HippocampusManager
|
||||||
|
from ...schedule.schedule_generator import bot_schedule
|
||||||
|
from ...config.config import global_config
|
||||||
from ...person_info.relationship_manager import relationship_manager
|
from ...person_info.relationship_manager import relationship_manager
|
||||||
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
logger = get_module_logger("prompt")
|
logger = get_module_logger("prompt")
|
||||||
|
|
||||||
@@ -128,7 +131,7 @@ class PromptBuilder:
|
|||||||
# 知识构建
|
# 知识构建
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
prompt_info = ""
|
prompt_info = ""
|
||||||
prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
|
prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f"""\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
|
prompt_info = f"""\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n"""
|
||||||
|
|
||||||
@@ -158,16 +161,156 @@ class PromptBuilder:
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
async def get_prompt_info(self, message: str, threshold: float):
|
async def get_prompt_info(self, message: str, threshold: float):
|
||||||
|
start_time = time.time()
|
||||||
related_info = ""
|
related_info = ""
|
||||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
embedding = await get_embedding(message, request_type="prompt_build")
|
|
||||||
related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
|
|
||||||
|
|
||||||
|
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||||
|
topics = []
|
||||||
|
try:
|
||||||
|
# 先尝试使用记忆系统的方法获取主题
|
||||||
|
hippocampus = HippocampusManager.get_instance()._hippocampus
|
||||||
|
topic_num = min(5, max(1, int(len(message) * 0.1)))
|
||||||
|
topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
|
||||||
|
|
||||||
|
# 提取关键词
|
||||||
|
topics = re.findall(r"<([^>]+)>", topics_response[0])
|
||||||
|
if not topics:
|
||||||
|
topics = []
|
||||||
|
else:
|
||||||
|
topics = [
|
||||||
|
topic.strip()
|
||||||
|
for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
|
if topic.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(f"从LLM提取的主题: {', '.join(topics)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从LLM提取主题失败: {str(e)}")
|
||||||
|
# 如果LLM提取失败,使用jieba分词提取关键词作为备选
|
||||||
|
words = jieba.cut(message)
|
||||||
|
topics = [word for word in words if len(word) > 1][:5]
|
||||||
|
logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
|
||||||
|
|
||||||
|
# 如果无法提取到主题,直接使用整个消息
|
||||||
|
if not topics:
|
||||||
|
logger.info("未能提取到任何主题,使用整个消息进行查询")
|
||||||
|
embedding = await get_embedding(message, request_type="prompt_build")
|
||||||
|
if not embedding:
|
||||||
|
logger.error("获取消息嵌入向量失败")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||||
|
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||||
return related_info
|
return related_info
|
||||||
|
|
||||||
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
|
# 2. 对每个主题进行知识库查询
|
||||||
if not query_embedding:
|
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||||
|
|
||||||
|
# 优化:批量获取嵌入向量,减少API调用
|
||||||
|
embeddings = {}
|
||||||
|
topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||||
|
if message: # 确保消息非空
|
||||||
|
topics_batch.append(message)
|
||||||
|
|
||||||
|
# 批量获取嵌入向量
|
||||||
|
embed_start_time = time.time()
|
||||||
|
for text in topics_batch:
|
||||||
|
if not text or len(text.strip()) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
embedding = await get_embedding(text, request_type="prompt_build")
|
||||||
|
if embedding:
|
||||||
|
embeddings[text] = embedding
|
||||||
|
else:
|
||||||
|
logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||||
|
|
||||||
|
if not embeddings:
|
||||||
|
logger.error("所有嵌入向量获取失败")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# 3. 对每个主题进行知识库查询
|
||||||
|
all_results = []
|
||||||
|
query_start_time = time.time()
|
||||||
|
|
||||||
|
# 首先添加原始消息的查询结果
|
||||||
|
if message in embeddings:
|
||||||
|
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||||
|
if original_results:
|
||||||
|
for result in original_results:
|
||||||
|
result["topic"] = "原始消息"
|
||||||
|
all_results.extend(original_results)
|
||||||
|
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||||
|
|
||||||
|
# 然后添加每个主题的查询结果
|
||||||
|
for topic in topics:
|
||||||
|
if not topic or topic not in embeddings:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||||
|
if topic_results:
|
||||||
|
# 添加主题标记
|
||||||
|
for result in topic_results:
|
||||||
|
result["topic"] = topic
|
||||||
|
all_results.extend(topic_results)
|
||||||
|
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||||
|
|
||||||
|
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||||
|
|
||||||
|
# 4. 去重和过滤
|
||||||
|
process_start_time = time.time()
|
||||||
|
unique_contents = set()
|
||||||
|
filtered_results = []
|
||||||
|
for result in all_results:
|
||||||
|
content = result["content"]
|
||||||
|
if content not in unique_contents:
|
||||||
|
unique_contents.add(content)
|
||||||
|
filtered_results.append(result)
|
||||||
|
|
||||||
|
# 5. 按相似度排序
|
||||||
|
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
|
|
||||||
|
# 6. 限制总数量(最多10条)
|
||||||
|
filtered_results = filtered_results[:10]
|
||||||
|
logger.info(f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果")
|
||||||
|
|
||||||
|
# 7. 格式化输出
|
||||||
|
if filtered_results:
|
||||||
|
format_start_time = time.time()
|
||||||
|
grouped_results = {}
|
||||||
|
for result in filtered_results:
|
||||||
|
topic = result["topic"]
|
||||||
|
if topic not in grouped_results:
|
||||||
|
grouped_results[topic] = []
|
||||||
|
grouped_results[topic].append(result)
|
||||||
|
|
||||||
|
# 按主题组织输出
|
||||||
|
for topic, results in grouped_results.items():
|
||||||
|
related_info += f"【主题: {topic}】\n"
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
similarity = result["similarity"]
|
||||||
|
content = result["content"].strip()
|
||||||
|
# 调试:为内容添加序号和相似度信息
|
||||||
|
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
|
||||||
|
related_info += f"{content}\n"
|
||||||
|
related_info += "\n"
|
||||||
|
|
||||||
|
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||||
|
|
||||||
|
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||||
|
return related_info
|
||||||
|
|
||||||
|
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
|
||||||
|
if not query_embedding:
|
||||||
|
return "" if not return_raw else []
|
||||||
# 使用余弦相似度计算
|
# 使用余弦相似度计算
|
||||||
pipeline = [
|
pipeline = [
|
||||||
{
|
{
|
||||||
@@ -221,11 +364,14 @@ class PromptBuilder:
|
|||||||
]
|
]
|
||||||
|
|
||||||
results = list(db.knowledges.aggregate(pipeline))
|
results = list(db.knowledges.aggregate(pipeline))
|
||||||
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
|
logger.debug(f"知识库查询结果数量: {len(results)}")
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return ""
|
return "" if not return_raw else []
|
||||||
|
|
||||||
|
if return_raw:
|
||||||
|
return results
|
||||||
|
else:
|
||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return "\n".join(str(result["content"]) for result in results)
|
return "\n".join(str(result["content"]) for result in results)
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class KnowledgeLibrary:
|
|||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
def split_content(self, content: str, max_length: int = 512) -> list:
|
def split_content(self, content: str, max_length: int = 512) -> list:
|
||||||
"""将内容分割成适当大小的块,保持段落完整性
|
"""将内容分割成适当大小的块,按空行分割
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 要分割的文本内容
|
content: 要分割的文本内容
|
||||||
@@ -50,66 +50,20 @@ class KnowledgeLibrary:
|
|||||||
Returns:
|
Returns:
|
||||||
list: 分割后的文本块列表
|
list: 分割后的文本块列表
|
||||||
"""
|
"""
|
||||||
# 首先按段落分割
|
# 按空行分割内容
|
||||||
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
|
paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
|
||||||
chunks = []
|
chunks = []
|
||||||
current_chunk = []
|
|
||||||
current_length = 0
|
|
||||||
|
|
||||||
for para in paragraphs:
|
for para in paragraphs:
|
||||||
para_length = len(para)
|
para_length = len(para)
|
||||||
|
|
||||||
# 如果单个段落就超过最大长度
|
# 如果段落长度小于等于最大长度,直接添加
|
||||||
if para_length > max_length:
|
if para_length <= max_length:
|
||||||
# 如果当前chunk不为空,先保存
|
chunks.append(para)
|
||||||
if current_chunk:
|
|
||||||
chunks.append("\n".join(current_chunk))
|
|
||||||
current_chunk = []
|
|
||||||
current_length = 0
|
|
||||||
|
|
||||||
# 将长段落按句子分割
|
|
||||||
sentences = [
|
|
||||||
s.strip()
|
|
||||||
for s in para.replace("。", "。\n").replace("!", "!\n").replace("?", "?\n").split("\n")
|
|
||||||
if s.strip()
|
|
||||||
]
|
|
||||||
temp_chunk = []
|
|
||||||
temp_length = 0
|
|
||||||
|
|
||||||
for sentence in sentences:
|
|
||||||
sentence_length = len(sentence)
|
|
||||||
if sentence_length > max_length:
|
|
||||||
# 如果单个句子超长,强制按长度分割
|
|
||||||
if temp_chunk:
|
|
||||||
chunks.append("\n".join(temp_chunk))
|
|
||||||
temp_chunk = []
|
|
||||||
temp_length = 0
|
|
||||||
for i in range(0, len(sentence), max_length):
|
|
||||||
chunks.append(sentence[i : i + max_length])
|
|
||||||
elif temp_length + sentence_length + 1 <= max_length:
|
|
||||||
temp_chunk.append(sentence)
|
|
||||||
temp_length += sentence_length + 1
|
|
||||||
else:
|
else:
|
||||||
chunks.append("\n".join(temp_chunk))
|
# 如果段落超过最大长度,则按最大长度切分
|
||||||
temp_chunk = [sentence]
|
for i in range(0, para_length, max_length):
|
||||||
temp_length = sentence_length
|
chunks.append(para[i:i + max_length])
|
||||||
|
|
||||||
if temp_chunk:
|
|
||||||
chunks.append("\n".join(temp_chunk))
|
|
||||||
|
|
||||||
# 如果当前段落加上现有chunk不超过最大长度
|
|
||||||
elif current_length + para_length + 1 <= max_length:
|
|
||||||
current_chunk.append(para)
|
|
||||||
current_length += para_length + 1
|
|
||||||
else:
|
|
||||||
# 保存当前chunk并开始新的chunk
|
|
||||||
chunks.append("\n".join(current_chunk))
|
|
||||||
current_chunk = [para]
|
|
||||||
current_length = para_length
|
|
||||||
|
|
||||||
# 添加最后一个chunk
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append("\n".join(current_chunk))
|
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user