diff --git a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
index e3015fe1e..d9e2cf75b 100644
--- a/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
+++ b/src/plugins/chat_module/reasoning_chat/reasoning_prompt_builder.py
@@ -1,16 +1,19 @@
 import random
 import time
-from typing import Optional
+from typing import Optional, Union
+import re
+import jieba
+import numpy as np
 
 from ....common.database import db
-from ...memory_system.Hippocampus import HippocampusManager
-from ...moods.moods import MoodManager
-from ...schedule.schedule_generator import bot_schedule
-from ...config.config import global_config
 from ...chat.utils import get_embedding, get_recent_group_detailed_plain_text, get_recent_group_speaker
 from ...chat.chat_stream import chat_manager
-from src.common.logger import get_module_logger
+from ...moods.moods import MoodManager
+from ...memory_system.Hippocampus import HippocampusManager
+from ...schedule.schedule_generator import bot_schedule
+from ...config.config import global_config
 from ...person_info.relationship_manager import relationship_manager
+from src.common.logger import get_module_logger
 
 
 logger = get_module_logger("prompt")
@@ -128,7 +131,7 @@ class PromptBuilder:
         # Knowledge construction
         start_time = time.time()
         prompt_info = ""
-        prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
+        prompt_info = await self.get_prompt_info(message_txt, threshold=0.38)
         if prompt_info:
             prompt_info = f"""\nYou have the following **knowledge**:\n{prompt_info}\nPlease **remember the knowledge above**; it may be needed later.\n"""
 
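Review note: the hunk above lowers the retrieval threshold from 0.5 to 0.38, so more loosely related knowledge entries pass the similarity filter. The aggregation pipeline that applies the threshold is elided from this diff; as a minimal sketch, assuming it scores entries with standard cosine similarity (the diff also adds `import numpy as np`, reused here), the effect of the lower cutoff looks like this — the query and document vectors are invented for illustration:

```python
import numpy as np

def cosine_similarity(a, b) -> float:
    """Standard cosine similarity between two embedding vectors."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Invented query/document embeddings, just to show the cutoff's effect.
query = [0.2, 0.7, 0.1]
docs = {"close_match": [0.21, 0.69, 0.12], "loose_match": [0.9, 0.1, 0.3]}

for name, vec in docs.items():
    sim = cosine_similarity(query, vec)
    # loose_match scores ~0.40: kept at threshold 0.38, dropped at 0.5
    print(f"{name}: {sim:.2f} -> {'kept' if sim >= 0.38 else 'dropped'}")
```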
@@ -158,16 +161,156 @@ class PromptBuilder:
         return prompt
 
     async def get_prompt_info(self, message: str, threshold: float):
+        start_time = time.time()
         related_info = ""
         logger.debug(f"Fetching knowledge base content, source message: {message[:30]}..., message length: {len(message)}")
-        embedding = await get_embedding(message, request_type="prompt_build")
-        related_info += self.get_info_from_db(embedding, limit=1, threshold=threshold)
-
+
+        # 1. First get topics from the LLM, similar to the memory system's approach
+        topics = []
+        try:
+            # Try the memory system's method for topic extraction first
+            hippocampus = HippocampusManager.get_instance()._hippocampus
+            topic_num = min(5, max(1, int(len(message) * 0.1)))
+            topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
+
+            # Extract the keywords
+            topics = re.findall(r"<([^>]+)>", topics_response[0])
+            if not topics:
+                topics = []
+            else:
+                topics = [
+                    topic.strip()
+                    for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
+                    if topic.strip()
+                ]
+
+            logger.info(f"Topics extracted by the LLM: {', '.join(topics)}")
+        except Exception as e:
+            logger.error(f"LLM topic extraction failed: {str(e)}")
+            # If LLM extraction fails, fall back to jieba keyword extraction
+            words = jieba.cut(message)
+            topics = [word for word in words if len(word) > 1][:5]
+            logger.info(f"Topics extracted with jieba: {', '.join(topics)}")
+
+        # If no topics could be extracted, query with the whole message directly
+        if not topics:
+            logger.info("No topics extracted; querying with the whole message")
+            embedding = await get_embedding(message, request_type="prompt_build")
+            if not embedding:
+                logger.error("Failed to get the message embedding")
+                return ""
+
+            related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
+            logger.info(f"Knowledge base retrieval finished, total time: {time.time() - start_time:.3f}s")
+            return related_info
+
+        # 2. Prepare the knowledge base queries for each topic
+        logger.info(f"Starting knowledge base queries for {len(topics)} topics")
+
+        # Optimization: fetch embeddings in a batch to reduce API calls
+        embeddings = {}
+        topics_batch = [topic for topic in topics if len(topic) > 0]
+        if message:  # make sure the message is non-empty
+            topics_batch.append(message)
+
+        # Fetch the embeddings in a batch
+        embed_start_time = time.time()
+        for text in topics_batch:
+            if not text or len(text.strip()) == 0:
+                continue
+
+            try:
+                embedding = await get_embedding(text, request_type="prompt_build")
+                if embedding:
+                    embeddings[text] = embedding
+                else:
+                    logger.warning(f"Failed to get embedding for '{text}'")
+            except Exception as e:
+                logger.error(f"Error while getting embedding for '{text}': {str(e)}")
+
+        logger.info(f"Batch embedding fetch finished in {time.time() - embed_start_time:.3f}s")
+
+        if not embeddings:
+            logger.error("All embedding lookups failed")
+            return ""
+
+        # 3. Query the knowledge base for each topic
+        all_results = []
+        query_start_time = time.time()
+
+        # First add the query results for the original message
+        if message in embeddings:
+            original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
+            if original_results:
+                for result in original_results:
+                    result["topic"] = "original message"
+                all_results.extend(original_results)
+                logger.info(f"Original message query returned {len(original_results)} results")
+
+        # Then add the query results for each topic
+        for topic in topics:
+            if not topic or topic not in embeddings:
+                continue
+
+            try:
+                topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
+                if topic_results:
+                    # Tag the results with their topic
+                    for result in topic_results:
+                        result["topic"] = topic
+                    all_results.extend(topic_results)
+                    logger.info(f"Topic '{topic}' query returned {len(topic_results)} results")
+            except Exception as e:
+                logger.error(f"Error while querying topic '{topic}': {str(e)}")
+
+        logger.info(f"Knowledge base queries finished in {time.time() - query_start_time:.3f}s, {len(all_results)} results in total")
+
+        # 4. Deduplicate and filter
+        process_start_time = time.time()
+        unique_contents = set()
+        filtered_results = []
+        for result in all_results:
+            content = result["content"]
+            if content not in unique_contents:
+                unique_contents.add(content)
+                filtered_results.append(result)
+
+        # 5. Sort by similarity
+        filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
+
+        # 6. Cap the total count (at most 10 entries)
+        filtered_results = filtered_results[:10]
+        logger.info(f"Result processing finished in {time.time() - process_start_time:.3f}s, {len(filtered_results)} results left after filtering")
+
+        # 7. Format the output
+        if filtered_results:
+            format_start_time = time.time()
+            grouped_results = {}
+            for result in filtered_results:
+                topic = result["topic"]
+                if topic not in grouped_results:
+                    grouped_results[topic] = []
+                grouped_results[topic].append(result)
+
+            # Organize the output by topic
+            for topic, results in grouped_results.items():
+                related_info += f"[Topic: {topic}]\n"
+                for i, result in enumerate(results, 1):
+                    similarity = result["similarity"]
+                    content = result["content"].strip()
+                    # Debug: prefix each entry with its index and similarity
+                    # related_info += f"{i}. [{similarity:.2f}] {content}\n"
+                    related_info += f"{content}\n"
+                related_info += "\n"
+
+            logger.info(f"Output formatting finished in {time.time() - format_start_time:.3f}s")
+
+        logger.info(f"Knowledge base retrieval total time: {time.time() - start_time:.3f}s")
         return related_info
 
-    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
+    def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False) -> Union[str, list]:
         if not query_embedding:
-            return ""
+            return "" if not return_raw else []
 
         # Compute cosine similarity
         pipeline = [
             {
@@ -221,13 +364,16 @@
         ]
 
         results = list(db.knowledges.aggregate(pipeline))
-        # print(f"\033[1;34m[Debug]\033[0m knowledge base query results: {results}")
+        logger.debug(f"Knowledge base query result count: {len(results)}")
 
         if not results:
-            return ""
+            return "" if not return_raw else []
 
-        # Return all matched content, separated by newlines
-        return "\n".join(str(result["content"]) for result in results)
+        if return_raw:
+            return results
+        else:
+            # Return all matched content, separated by newlines
+            return "\n".join(str(result["content"]) for result in results)
 
 
 prompt_builder = PromptBuilder()
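Review note: the keyword normalization chain in `get_prompt_info` is dense, so here is a minimal standalone sketch of what it does — `re.findall` pulls the `<...>` tags out of the LLM response, then the join/replace/split pass normalizes full-width commas, enumeration commas, and spaces into one separator. The sample response string is invented:

```python
import re

# Invented example of a topic-judge LLM response.
topics_response = "<天气>,<北京 天气>、<旅行>"

# Pull out the <...> tags, exactly as the diff does.
topics = re.findall(r"<([^>]+)>", topics_response)

# Normalize mixed separators into ASCII commas, then split and strip.
topics = [
    topic.strip()
    for topic in ",".join(topics).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
    if topic.strip()
]
print(topics)  # ['天气', '北京', '天气', '旅行']
```

Note that replacing spaces with commas splits multi-word topics apart and can produce duplicates, as the sample shows; that may be worth a follow-up.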
diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py
index a95a096e6..cf38874ce 100644
--- a/src/plugins/zhishi/knowledge_library.py
+++ b/src/plugins/zhishi/knowledge_library.py
@@ -41,7 +41,7 @@ class KnowledgeLibrary:
             return f.read()
 
     def split_content(self, content: str, max_length: int = 512) -> list:
-        """Split content into appropriately sized chunks, preserving paragraph integrity
+        """Split content into appropriately sized chunks, splitting on blank lines
 
         Args:
             content: the text content to split
@@ -50,67 +50,21 @@ class KnowledgeLibrary:
         Returns:
             list: the list of text chunks after splitting
         """
-        # First split by paragraph
+        # Split the content on blank lines
         paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
         chunks = []
-        current_chunk = []
-        current_length = 0
-
+
         for para in paragraphs:
             para_length = len(para)
-
-            # If a single paragraph alone exceeds the max length
-            if para_length > max_length:
-                # If the current chunk is non-empty, save it first
-                if current_chunk:
-                    chunks.append("\n".join(current_chunk))
-                    current_chunk = []
-                    current_length = 0
-
-                # Split the long paragraph into sentences
-                sentences = [
-                    s.strip()
-                    for s in para.replace("。", "。\n").replace("!", "!\n").replace("?", "?\n").split("\n")
-                    if s.strip()
-                ]
-                temp_chunk = []
-                temp_length = 0
-
-                for sentence in sentences:
-                    sentence_length = len(sentence)
-                    if sentence_length > max_length:
-                        # If a single sentence is too long, force-split it by length
-                        if temp_chunk:
-                            chunks.append("\n".join(temp_chunk))
-                            temp_chunk = []
-                            temp_length = 0
-                        for i in range(0, len(sentence), max_length):
-                            chunks.append(sentence[i : i + max_length])
-                    elif temp_length + sentence_length + 1 <= max_length:
-                        temp_chunk.append(sentence)
-                        temp_length += sentence_length + 1
-                    else:
-                        chunks.append("\n".join(temp_chunk))
-                        temp_chunk = [sentence]
-                        temp_length = sentence_length
-
-                if temp_chunk:
-                    chunks.append("\n".join(temp_chunk))
-
-            # If adding the paragraph to the current chunk stays within the max length
-            elif current_length + para_length + 1 <= max_length:
-                current_chunk.append(para)
-                current_length += para_length + 1
+
+            # If the paragraph fits within the max length, add it directly
+            if para_length <= max_length:
+                chunks.append(para)
             else:
-                # Save the current chunk and start a new one
-                chunks.append("\n".join(current_chunk))
-                current_chunk = [para]
-                current_length = para_length
-
-        # Add the last chunk
-        if current_chunk:
-            chunks.append("\n".join(current_chunk))
-
+                # If the paragraph exceeds the max length, slice it by max_length
+                for i in range(0, para_length, max_length):
+                    chunks.append(para[i:i + max_length])
+
         return chunks
 
     def get_embedding(self, text: str) -> list:
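Review note: the simplified chunker trades sentence-aware splitting for plain blank-line splitting plus hard slicing, so overlong paragraphs can now be cut mid-sentence. A minimal sketch of the new behavior, with the splitting logic copied out of the class so it runs standalone and the sample input invented:

```python
def split_content(content: str, max_length: int = 512) -> list:
    """Standalone copy of the simplified chunker: split on blank lines,
    then hard-slice any paragraph longer than max_length."""
    paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
    chunks = []
    for para in paragraphs:
        if len(para) <= max_length:
            chunks.append(para)
        else:
            # Hard slice: may cut mid-sentence, unlike the removed version
            for i in range(0, len(para), max_length):
                chunks.append(para[i:i + max_length])
    return chunks

sample = "short paragraph\n\n" + "x" * 1100
print([len(c) for c in split_content(sample)])  # [15, 512, 512, 76]
```

The design choice is speed and simplicity over chunk quality; if retrieval accuracy drops, the removed sentence-boundary logic is the obvious thing to restore.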