From afaf24d28c189221b151fff63485fb7ed5ae3bc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Tue, 20 May 2025 23:03:28 +0800 Subject: [PATCH] =?UTF-8?q?update:=20=E9=A3=9E=E6=8E=89=E8=80=81=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/__init__.py | 2 +- .../focus_chat/heartflow_prompt_builder.py | 218 +----------- src/chat/knowledge/src/qa_manager.py | 2 +- src/chat/zhishi/knowledge_library.py | 312 ------------------ 4 files changed, 6 insertions(+), 528 deletions(-) delete mode 100644 src/chat/zhishi/knowledge_library.py diff --git a/src/chat/__init__.py b/src/chat/__init__.py index 931c30ff3..1e859ffb7 100644 --- a/src/chat/__init__.py +++ b/src/chat/__init__.py @@ -1,5 +1,5 @@ """ -MaiMBot插件系统 +MaiBot模块系统 包含聊天、情绪、记忆、日程等功能模块 """ diff --git a/src/chat/focus_chat/heartflow_prompt_builder.py b/src/chat/focus_chat/heartflow_prompt_builder.py index 3209bbe46..b238c7492 100644 --- a/src/chat/focus_chat/heartflow_prompt_builder.py +++ b/src/chat/focus_chat/heartflow_prompt_builder.py @@ -4,17 +4,13 @@ from src.individuality.individuality import individuality from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.person_info.relationship_manager import relationship_manager -from src.chat.utils.utils import get_embedding import time -from typing import Union, Optional +from typing import Optional from src.chat.utils.utils import get_recent_group_speaker from src.manager.mood_manager import mood_manager from src.chat.memory_system.Hippocampus import HippocampusManager from src.chat.knowledge.knowledge_lib import qa_manager import random -import json -import math -from src.common.database.database_model import Knowledges logger = get_logger("prompt") @@ -262,130 +258,6 @@ class PromptBuilder: # --- End choosing template --- return prompt - - async def get_prompt_info_old(self, message: str, threshold: float): - start_time = time.time() - related_info = "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 1. 先从LLM获取主题,类似于记忆系统的做法 - topics = [] - - # 如果无法提取到主题,直接使用整个消息 - if not topics: - logger.info("未能提取到任何主题,使用整个消息进行查询") - embedding = await get_embedding(message, request_type="prompt_build") - if not embedding: - logger.error("获取消息嵌入向量失败") - return "" - - related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}秒") - return related_info - - # 2. 对每个主题进行知识库查询 - logger.info(f"开始处理{len(topics)}个主题的知识库查询") - - # 优化:批量获取嵌入向量,减少API调用 - embeddings = {} - topics_batch = [topic for topic in topics if len(topic) > 0] - if message: # 确保消息非空 - topics_batch.append(message) - - # 批量获取嵌入向量 - embed_start_time = time.time() - for text in topics_batch: - if not text or len(text.strip()) == 0: - continue - - try: - embedding = await get_embedding(text, request_type="prompt_build") - if embedding: - embeddings[text] = embedding - else: - logger.warning(f"获取'{text}'的嵌入向量失败") - except Exception as e: - logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}") - - logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}秒") - - if not embeddings: - logger.error("所有嵌入向量获取失败") - return "" - - # 3. 
对每个主题进行知识库查询 - all_results = [] - query_start_time = time.time() - - # 首先添加原始消息的查询结果 - if message in embeddings: - original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True) - if original_results: - for result in original_results: - result["topic"] = "原始消息" - all_results.extend(original_results) - logger.info(f"原始消息查询到{len(original_results)}条结果") - - # 然后添加每个主题的查询结果 - for topic in topics: - if not topic or topic not in embeddings: - continue - - try: - topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True) - if topic_results: - # 添加主题标记 - for result in topic_results: - result["topic"] = topic - all_results.extend(topic_results) - logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果") - except Exception as e: - logger.error(f"查询主题'{topic}'时发生错误: {str(e)}") - - logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果") - - # 4. 去重和过滤 - process_start_time = time.time() - unique_contents = set() - filtered_results = [] - for result in all_results: - content = result["content"] - if content not in unique_contents: - unique_contents.add(content) - filtered_results.append(result) - - # 5. 按相似度排序 - filtered_results.sort(key=lambda x: x["similarity"], reverse=True) - - # 6. 限制总数量(最多10条) - filtered_results = filtered_results[:10] - logger.info( - f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果" - ) - - # 7. 格式化输出 - if filtered_results: - format_start_time = time.time() - grouped_results = {} - for result in filtered_results: - topic = result["topic"] - if topic not in grouped_results: - grouped_results[topic] = [] - grouped_results[topic].append(result) - - # 按主题组织输出 - for topic, results in grouped_results.items(): - related_info += f"【主题: {topic}】\n" - for _i, result in enumerate(results, 1): - _similarity = result["similarity"] - content = result["content"].strip() - related_info += f"{content}\n" - related_info += "\n" - - logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}秒") - - logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}秒") - return related_info - async def get_prompt_info(self, message: str, threshold: float): related_info = "" start_time = time.time() @@ -405,93 +277,11 @@ class PromptBuilder: logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") return related_info else: - logger.debug("从LPMM知识库获取知识失败,使用旧版数据库进行检索") - knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) - related_info += knowledge_from_old - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - return related_info + logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") + return "未检索到知识" except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") - try: - knowledge_from_old = await self.get_prompt_info_old(message, threshold=threshold) - related_info += knowledge_from_old - logger.debug( - f"异常后使用旧版数据库获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}" - ) - return related_info - except Exception as e2: - logger.error(f"使用旧版数据库获取知识时也发生异常: {str(e2)}") - return "" - - @staticmethod - def get_info_from_db( - query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - if not query_embedding: - return "" if not return_raw else [] - - results_with_similarity = [] - try: - # Fetch all knowledge entries - # This might be inefficient for very large databases. 
- # Consider strategies like FAISS or other vector search libraries if performance becomes an issue. - all_knowledges = Knowledges.select() - - if not all_knowledges: - return [] if return_raw else "" - - query_embedding_magnitude = math.sqrt(sum(x * x for x in query_embedding)) - if query_embedding_magnitude == 0: # Avoid division by zero - return "" if not return_raw else [] - - for knowledge_item in all_knowledges: - try: - db_embedding_str = knowledge_item.embedding - db_embedding = json.loads(db_embedding_str) - - if len(db_embedding) != len(query_embedding): - logger.warning( - f"Embedding length mismatch for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}. Skipping." - ) - continue - - # Calculate Cosine Similarity - dot_product = sum(q * d for q, d in zip(query_embedding, db_embedding)) - db_embedding_magnitude = math.sqrt(sum(x * x for x in db_embedding)) - - if db_embedding_magnitude == 0: # Avoid division by zero - similarity = 0.0 - else: - similarity = dot_product / (query_embedding_magnitude * db_embedding_magnitude) - - if similarity >= threshold: - results_with_similarity.append({"content": knowledge_item.content, "similarity": similarity}) - except json.JSONDecodeError: - logger.error( - f"Failed to parse embedding for knowledge ID {knowledge_item.id if hasattr(knowledge_item, 'id') else 'N/A'}" - ) - except Exception as e: - logger.error(f"Error processing knowledge item: {e}") - - # Sort by similarity in descending order - results_with_similarity.sort(key=lambda x: x["similarity"], reverse=True) - - # Limit results - limited_results = results_with_similarity[:limit] - - logger.debug(f"知识库查询结果数量 (after Peewee processing): {len(limited_results)}") - - if not limited_results: - return "" if not return_raw else [] - - if return_raw: - return limited_results - else: - return "\n".join(str(result["content"]) for result in limited_results) - - except Exception as e: - logger.error(f"Error querying Knowledges with Peewee: {e}") - return "" if not return_raw else [] + return "未检索到知识" init_prompt() diff --git a/src/chat/knowledge/src/qa_manager.py b/src/chat/knowledge/src/qa_manager.py index 11067d0e5..b6bbd1207 100644 --- a/src/chat/knowledge/src/qa_manager.py +++ b/src/chat/knowledge/src/qa_manager.py @@ -121,5 +121,5 @@ class QAManager: found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n" return found_knowledge else: - logger.info("LPMM知识库并未初始化,使用旧版数据库进行检索") + logger.info("LPMM知识库并未初始化,可能是从未导入过知识...") return None diff --git a/src/chat/zhishi/knowledge_library.py b/src/chat/zhishi/knowledge_library.py deleted file mode 100644 index 0068a153c..000000000 --- a/src/chat/zhishi/knowledge_library.py +++ /dev/null @@ -1,312 +0,0 @@ -import os -import sys -import requests -from dotenv import load_dotenv -import hashlib -from datetime import datetime -from tqdm import tqdm -from rich.console import Console -from rich.table import Table -from rich.traceback import install - -install(extra_lines=3) - -# 添加项目根目录到 Python 路径 -root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) -sys.path.append(root_path) - -# 现在可以导入src模块 -from common.database.database import db # noqa E402 - - -# 加载根目录下的env.edv文件 -env_path = os.path.join(root_path, ".env") -if not os.path.exists(env_path): - raise FileNotFoundError(f"配置文件不存在: {env_path}") -load_dotenv(env_path) - - -class KnowledgeLibrary: - def __init__(self): - self.raw_info_dir = "data/raw_info" - self._ensure_dirs() - self.api_key = os.getenv("SILICONFLOW_KEY") - if not self.api_key: - 
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置") - self.console = Console() - - def _ensure_dirs(self): - """确保必要的目录存在""" - os.makedirs(self.raw_info_dir, exist_ok=True) - - @staticmethod - def read_file(file_path: str) -> str: - """读取文件内容""" - with open(file_path, "r", encoding="utf-8") as f: - return f.read() - - @staticmethod - def split_content(content: str, max_length: int = 512) -> list: - """将内容分割成适当大小的块,按空行分割 - - Args: - content: 要分割的文本内容 - max_length: 每个块的最大长度 - - Returns: - list: 分割后的文本块列表 - """ - # 按空行分割内容 - paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()] - chunks = [] - - for para in paragraphs: - para_length = len(para) - - # 如果段落长度小于等于最大长度,直接添加 - if para_length <= max_length: - chunks.append(para) - else: - # 如果段落超过最大长度,则按最大长度切分 - for i in range(0, para_length, max_length): - chunks.append(para[i : i + max_length]) - - return chunks - - def get_embedding(self, text: str) -> list: - """获取文本的embedding向量""" - url = "https://api.siliconflow.cn/v1/embeddings" - payload = {"model": "BAAI/bge-m3", "input": text, "encoding_format": "float"} - headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - - response = requests.post(url, json=payload, headers=headers) - if response.status_code != 200: - print(f"获取embedding失败: {response.text}") - return None - - return response.json()["data"][0]["embedding"] - - def process_files(self, knowledge_length: int = 512): - """处理raw_info目录下的所有txt文件""" - txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith(".txt")] - - if not txt_files: - self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir)) - self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]") - return - - total_stats = {"processed_files": 0, "total_chunks": 0, "failed_files": [], "skipped_files": []} - - self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]") - - for filename in tqdm(txt_files, desc="处理文件进度"): - file_path = os.path.join(self.raw_info_dir, filename) - result = self.process_single_file(file_path, knowledge_length) - self._update_stats(total_stats, result, filename) - - self._display_processing_results(total_stats) - - def process_single_file(self, file_path: str, knowledge_length: int = 512): - """处理单个文件""" - result = {"status": "success", "chunks_processed": 0, "error": None} - - try: - current_hash = self.calculate_file_hash(file_path) - processed_record = db.processed_files.find_one({"file_path": file_path}) - - if processed_record: - if processed_record.get("hash") == current_hash: - if knowledge_length in processed_record.get("split_by", []): - result["status"] = "skipped" - return result - - content = self.read_file(file_path) - chunks = self.split_content(content, knowledge_length) - - for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False): - embedding = self.get_embedding(chunk) - if embedding: - knowledge = { - "content": chunk, - "embedding": embedding, - "source_file": file_path, - "split_length": knowledge_length, - "created_at": datetime.now(), - } - db.knowledges.insert_one(knowledge) - result["chunks_processed"] += 1 - - split_by = processed_record.get("split_by", []) if processed_record else [] - if knowledge_length not in split_by: - split_by.append(knowledge_length) - - db.knowledges.processed_files.update_one( - {"file_path": file_path}, - {"$set": {"hash": current_hash, "last_processed": datetime.now(), "split_by": split_by}}, - upsert=True, - ) - - except Exception as e: - result["status"] = "failed" - 
result["error"] = str(e) - - return result - - @staticmethod - def _update_stats(total_stats, result, filename): - """更新总体统计信息""" - if result["status"] == "success": - total_stats["processed_files"] += 1 - total_stats["total_chunks"] += result["chunks_processed"] - elif result["status"] == "failed": - total_stats["failed_files"].append((filename, result["error"])) - elif result["status"] == "skipped": - total_stats["skipped_files"].append(filename) - - def _display_processing_results(self, stats): - """显示处理结果统计""" - self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]") - - table = Table(show_header=True, header_style="bold magenta") - table.add_column("统计项", style="dim") - table.add_column("数值") - - table.add_row("成功处理文件数", str(stats["processed_files"])) - table.add_row("处理的知识块总数", str(stats["total_chunks"])) - table.add_row("跳过的文件数", str(len(stats["skipped_files"]))) - table.add_row("失败的文件数", str(len(stats["failed_files"]))) - - self.console.print(table) - - if stats["failed_files"]: - self.console.print("\n[bold red]处理失败的文件:[/bold red]") - for filename, error in stats["failed_files"]: - self.console.print(f"[red]- {filename}: {error}[/red]") - - if stats["skipped_files"]: - self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]") - for filename in stats["skipped_files"]: - self.console.print(f"[yellow]- {filename}[/yellow]") - - @staticmethod - def calculate_file_hash(file_path): - """计算文件的MD5哈希值""" - hash_md5 = hashlib.md5() - with open(file_path, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() - - def search_similar_segments(self, query: str, limit: int = 5) -> list: - """搜索与查询文本相似的片段""" - query_embedding = self.get_embedding(query) - if not query_embedding: - return [] - - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - { - "$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]}, - ] - }, - ] - }, - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}, - } - } - }, - } - }, - {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}}, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1, "file_path": 1}}, - ] - - results = list(db.knowledges.aggregate(pipeline)) - return results - - -# 创建单例实例 -knowledge_library = KnowledgeLibrary() - -if __name__ == "__main__": - console = Console() - console.print("[bold green]知识库处理工具[/bold green]") - - while True: - console.print("\n请选择要执行的操作:") - console.print("[1] 麦麦开始学习") - console.print("[2] 麦麦全部忘光光(仅知识)") - console.print("[q] 退出程序") - - choice = input("\n请输入选项: ").strip() - - if choice.lower() == "q": - console.print("[yellow]程序退出[/yellow]") - sys.exit(0) - elif choice == "2": - confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower() - if confirm == "y": - db.knowledges.delete_many({}) - console.print("[green]已清空所有知识![/green]") - continue - elif choice == "1": - if not os.path.exists(knowledge_library.raw_info_dir): - 
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]") - os.makedirs(knowledge_library.raw_info_dir, exist_ok=True) - - # 询问分割长度 - while True: - try: - length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip() - if length_input.lower() == "q": - break - if not length_input: # 如果直接回车,使用默认值 - knowledge_length = 512 - break - knowledge_length = int(length_input) - if knowledge_length <= 0: - print("分割长度必须大于0,请重新输入") - continue - break - except ValueError: - print("请输入有效的数字") - continue - - if length_input.lower() == "q": - continue - - # 测试知识库功能 - print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...") - knowledge_library.process_files(knowledge_length=knowledge_length) - else: - console.print("[red]无效的选项,请重新选择[/red]") - continue
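For reference, the two retrieval paths deleted in this patch (the Peewee-backed get_info_from_db in heartflow_prompt_builder.py and the MongoDB aggregation pipeline in knowledge_library.py) reduce to the same computation: cosine similarity between a query embedding and each stored embedding, ranked in descending order and limited to the top matches (the Peewee path additionally applies a similarity threshold). Below is a minimal standalone sketch of that computation, not code from the repository; the entries argument, a list of {"content", "embedding"} records, is a hypothetical stand-in for the deleted Knowledges rows / MongoDB documents.

import math

def cosine_similarity(a: list, b: list) -> float:
    # Same formula the deleted code used: dot(a, b) / (|a| * |b|), with 0.0 when either vector has zero magnitude
    dot = sum(x * y for x, y in zip(a, b))
    mag_a = math.sqrt(sum(x * x for x in a))
    mag_b = math.sqrt(sum(x * x for x in b))
    if mag_a == 0 or mag_b == 0:
        return 0.0
    return dot / (mag_a * mag_b)

def retrieve(query_embedding: list, entries: list, threshold: float = 0.5, limit: int = 3) -> list:
    # entries: hypothetical [{"content": str, "embedding": list[float]}, ...] records,
    # standing in for the deleted Knowledges rows; embeddings with a mismatched
    # length are skipped, as the removed get_info_from_db did
    scored = [
        {"content": e["content"], "similarity": cosine_similarity(query_embedding, e["embedding"])}
        for e in entries
        if len(e["embedding"]) == len(query_embedding)
    ]
    scored = [s for s in scored if s["similarity"] >= threshold]
    scored.sort(key=lambda s: s["similarity"], reverse=True)
    return scored[:limit]

With that fallback removed, get_prompt_info now relies solely on the LPMM qa_manager lookup and returns the placeholder string "未检索到知识" ("no knowledge retrieved") both when the lookup finds nothing and when it raises, rather than falling back to the database-backed path sketched above.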