Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -4,17 +4,13 @@ from src.individuality.individuality import individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.person_info.relationship_manager import relationship_manager
|
||||
from src.chat.utils.utils import get_embedding
|
||||
import time
|
||||
from typing import Union, Optional
|
||||
from typing import Optional
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||
import random
|
||||
import json
|
||||
import math
|
||||
from src.common.database.database_model import Knowledges
|
||||
|
||||
|
||||
logger = get_logger("prompt")
|
||||
@@ -263,129 +259,6 @@ class PromptBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
async def get_prompt_info_old(self, message: str, threshold: float):
|
||||
start_time = time.time()
|
||||
related_info = ""
|
||||
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||
# 1. 先从LLM获取主题,类似于记忆系统的做法
|
||||
topics = []
|
||||
|
||||
# 如果无法提取到主题,直接使用整个消息
|
||||
if not topics:
|
||||
logger.info("未能提取到任何主题,使用整个消息进行查询")
|
||||
embedding = await get_embedding(message, request_type="prompt_build")
|
||||
if not embedding:
|
||||
logger.error("获取消息嵌入向量失败")
|
||||
return ""
|
||||
|
||||
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
|
||||
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒")
|
||||
return related_info
|
||||
|
||||
# 2. 对每个主题进行知识库查询
|
||||
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
|
||||
|
||||
# 优化:批量获取嵌入向量,减少API调用
|
||||
embeddings = {}
|
||||
topics_batch = [topic for topic in topics if len(topic) > 0]
|
||||
if message: # 确保消息非空
|
||||
topics_batch.append(message)
|
||||
|
||||
# 批量获取嵌入向量
|
||||
embed_start_time = time.time()
|
||||
for text in topics_batch:
|
||||
if not text or len(text.strip()) == 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
embedding = await get_embedding(text, request_type="prompt_build")
|
||||
if embedding:
|
||||
embeddings[text] = embedding
|
||||
else:
|
||||
logger.warning(f"获取'{text}'的嵌入向量失败")
|
||||
except Exception as e:
|
||||
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
|
||||
|
||||
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒")
|
||||
|
||||
if not embeddings:
|
||||
logger.error("所有嵌入向量获取失败")
|
||||
return ""
|
||||
|
||||
# 3. 对每个主题进行知识库查询
|
||||
all_results = []
|
||||
query_start_time = time.time()
|
||||
|
||||
# 首先添加原始消息的查询结果
|
||||
if message in embeddings:
|
||||
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
|
||||
if original_results:
|
||||
for result in original_results:
|
||||
result["topic"] = "原始消息"
|
||||
all_results.extend(original_results)
|
||||
logger.info(f"原始消息查询到{len(original_results)}条结果")
|
||||
|
||||
# 然后添加每个主题的查询结果
|
||||
for topic in topics:
|
||||
if not topic or topic not in embeddings:
|
||||
continue
|
||||
|
||||
try:
|
||||
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
|
||||
if topic_results:
|
||||
# 添加主题标记
|
||||
for result in topic_results:
|
||||
result["topic"] = topic
|
||||
all_results.extend(topic_results)
|
||||
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
|
||||
except Exception as e:
|
||||
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
|
||||
|
||||
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
|
||||
|
||||
# 4. 去重和过滤
|
||||
process_start_time = time.time()
|
||||
unique_contents = set()
|
||||
filtered_results = []
|
||||
for result in all_results:
|
||||
content = result["content"]
|
||||
if content not in unique_contents:
|
||||
unique_contents.add(content)
|
||||
filtered_results.append(result)
|
||||
|
||||
# 5. 按相似度排序
|
||||
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# 6. 限制总数量(最多10条)
|
||||
filtered_results = filtered_results[:10]
|
||||
logger.info(
|
||||
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
|
||||
)
|
||||
|
||||
# 7. 格式化输出
|
||||
if filtered_results:
|
||||
format_start_time = time.time()
|
||||
grouped_results = {}
|
||||
for result in filtered_results:
|
||||
topic = result["topic"]
|
||||
if topic not in grouped_results:
|
||||
grouped_results[topic] = []
|
||||
grouped_results[topic].append(result)
|
||||
|
||||
# 按主题组织输出
|
||||
for topic, results in grouped_results.items():
|
||||
related_info += f"【主题: {topic}】\n"
|
||||
for _i, result in enumerate(results, 1):
|
||||
_similarity = result["similarity"]
|
||||
content = result["content"].strip()
|
||||
related_info += f"{content}\n"
|
||||
related_info += "\n"
|
||||
|
||||
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒")
|
||||
|
||||
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒")
|
||||
return related_info
|
||||
|
||||
async def get_prompt_info(self, message: str, threshold: float):
|
||||
related_info = ""
|
||||
start_time = time.time()
|
||||
@@ -405,93 +278,11 @@ class PromptBuilder:
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
else:
|
||||
logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索")
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
|
||||
return related_info
|
||||
logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
|
||||
return "未检索到知识"
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识库内容时发生异常: {str(e)}")
|
||||
try:
|
||||
knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold)
|
||||
related_info += knowledge_from_old
|
||||
logger.debug(
|
||||
f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}"
|
||||
)
|
||||
return related_info
|
||||
except Exception as e2:
|
||||
logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def get_info_from_db(
|
||||
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
||||
) -> Union[str, list]:
|
||||
if not query_embedding:
|
||||
return "" if not return_raw else []
|
||||
|
||||
results_with_similarity = []
|
||||
try:
|
||||
# Fetch all knowledge entries
|
||||
# This might be inefficient for very large databases.
|
||||
# Consider strategies like FAISS or other vector search libraries if performance becomes an issue.
|
||||
all_knowledges = Knowledges.select()
|
||||
|
||||
if not all_knowledges:
|
||||
return [] if return_raw else ""
|
||||
|
||||
query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding))
|
||||
if query_embedding_magnitude == 0: # Avoid division by zero
|
||||
return "" if not return_raw else []
|
||||
|
||||
for knowledge_item in all_knowledges:
|
||||
try:
|
||||
db_embedding_str = knowledge_item.embedding
|
||||
db_embedding = json.loads(db_embedding_str)
|
||||
|
||||
if len(db_embedding) != len(query_embedding):
|
||||
logger.warning(
|
||||
f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate Cosine Similarity
|
||||
dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding))
|
||||
db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding))
|
||||
|
||||
if db_embedding_magnitude == 0: # Avoid division by zero
|
||||
similarity = 0.0
|
||||
else:
|
||||
similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude)
|
||||
|
||||
if similarity >= threshold:
|
||||
results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity})
|
||||
except json.JSONDecodeError:
|
||||
logger.error(
|
||||
f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing knowledge item: {e}")
|
||||
|
||||
# Sort by similarity in descending order
|
||||
results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
|
||||
# Limit results
|
||||
limited_results = results_with_similarity[:limit]
|
||||
|
||||
logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}")
|
||||
|
||||
if not limited_results:
|
||||
return "" if not return_raw else []
|
||||
|
||||
if return_raw:
|
||||
return limited_results
|
||||
else:
|
||||
return "\n".join(str(result["content"]) for result in limited_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying Knowledges with Peewee: {e}")
|
||||
return "" if not return_raw else []
|
||||
return "未检索到知识"
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
Reference in New Issue
Block a user