diff --git a/.gitignore b/.gitignore index e51abc5cc..6e1be60b4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ data/ +data1/ mongodb/ NapCat.Framework.Windows.Once/ log/ diff --git a/run.py b/run.py index 50e312c37..cfd3a5f14 100644 --- a/run.py +++ b/run.py @@ -128,13 +128,17 @@ if __name__ == "__main__": ) os.system("cls") if choice == "1": - install_napcat() - install_mongodb() + confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n") + if confirm == "1": + install_napcat() + install_mongodb() + else: + print("已取消安装") elif choice == "2": run_maimbot() - choice = input("是否启动推理可视化?(y/N)").upper() + choice = input("是否启动推理可视化?(未完善)(y/N)").upper() if choice == "Y": run_cmd(r"python src\gui\reasoning_gui.py") - choice = input("是否启动记忆可视化?(y/N)").upper() + choice = input("是否启动记忆可视化?(未完善)(y/N)").upper() if choice == "Y": run_cmd(r"python src/plugins/memory_system/memory_manual_build.py") diff --git a/src/common/database.py b/src/common/database.py index f0954b07c..d592b0f90 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,8 +1,6 @@ from typing import Optional - from pymongo import MongoClient - class Database: _instance: Optional["Database"] = None @@ -50,25 +48,4 @@ class Database: def get_instance(cls) -> "Database": if cls._instance is None: raise RuntimeError("Database not initialized") - return cls._instance - - - #测试用 - - def get_random_group_messages(self, group_id: str, limit: int = 5): - # 先随机获取一条消息 - random_message = list(self.db.messages.aggregate([ - {"$match": {"group_id": group_id}}, - {"$sample": {"size": 1}} - ]))[0] - - # 获取该消息之后的消息 - subsequent_messages = list(self.db.messages.find({ - "group_id": group_id, - "time": {"$gt": random_message["time"]} - }).sort("time", 1).limit(limit)) - - # 将随机消息和后续消息合并 - messages = [random_message] + subsequent_messages - - return messages \ No newline at end of file + return cls._instance \ No newline at end of file diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 9b2ac06f1..5bd502a7e 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -7,7 +7,6 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent from ..memory_system.memory import hippocampus from ..moods.moods import MoodManager # 导入情绪管理器 from .config import global_config -from .cq_code import CQCode, cq_code_tool # 导入CQCode模块 from .emoji_manager import emoji_manager # 导入表情包管理器 from .llm_generator import ResponseGenerator from .message import MessageSending, MessageRecv, MessageThinking, MessageSet @@ -218,7 +217,7 @@ class ChatBot: # message_set 可以直接加入 message_manager # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") - print(f"添加message_set到message_manager") + print("添加message_set到message_manager") message_manager.add_message(message_set) bot_response_time = thinking_time_point diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index af7334afe..46dc34e92 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -8,7 +8,7 @@ from loguru import logger from ...common.database import Database from ..models.utils_model import LLM_request from .config import global_config -from .message import MessageRecv, MessageThinking, MessageSending,Message +from .message import MessageRecv, MessageThinking, Message from .prompt_builder import prompt_builder from .relationship_manager import relationship_manager from .utils import process_llm_response diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 0505c05a6..d848f068f 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -3,14 +3,14 @@ import html import re import json from dataclasses import dataclass -from typing import Dict, ForwardRef, List, Optional, Union +from typing import Dict, List, Optional import urllib3 from loguru import logger from .utils_image import image_manager -from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase -from .chat_stream import ChatStream, chat_manager +from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase +from .chat_stream import ChatStream # 禁用SSL警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/src/plugins/chat/message_base.py b/src/plugins/chat/message_base.py index d17c2c357..ae7ec3872 100644 --- a/src/plugins/chat/message_base.py +++ b/src/plugins/chat/message_base.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, asdict -from typing import List, Optional, Union, Any, Dict +from typing import List, Optional, Union, Dict @dataclass class Seg: diff --git a/src/plugins/chat/message_cq.py b/src/plugins/chat/message_cq.py index 6bfa47c3f..cb47ae4b3 100644 --- a/src/plugins/chat/message_cq.py +++ b/src/plugins/chat/message_cq.py @@ -1,12 +1,12 @@ import time from dataclasses import dataclass -from typing import Dict, ForwardRef, List, Optional, Union +from typing import Dict, Optional import urllib3 -from .cq_code import CQCode, cq_code_tool +from .cq_code import cq_code_tool from .utils_cq import parse_cq_code -from .utils_user import get_groupname, get_user_cardname, get_user_nickname +from .utils_user import get_groupname from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase # 禁用SSL警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index 9db74633f..eefa6f4ae 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -5,12 +5,10 @@ from typing import Dict, List, Optional, Union from loguru import logger from nonebot.adapters.onebot.v11 import Bot -from .cq_code import cq_code_tool from .message_cq import MessageSendCQ -from .message import MessageSending, MessageThinking, MessageRecv,MessageSet +from .message import MessageSending, MessageThinking, MessageSet from .storage import MessageStorage from .config import global_config -from .chat_stream import chat_manager class Message_Sender: diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index fec6c7926..b97666763 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -9,7 +9,7 @@ from ..moods.moods import MoodManager from ..schedule.schedule_generator import bot_schedule from .config import global_config from .utils import get_embedding, get_recent_group_detailed_plain_text -from .chat_stream import ChatStream, chat_manager +from .chat_stream import chat_manager class PromptBuilder: diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index 9e7cafda0..90e92e7b6 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -1,6 +1,5 @@ import asyncio -from typing import Optional, Union -from typing import Optional, Union +from typing import Optional from loguru import logger from ...common.database import Database diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index f403b2c8b..c3986a2d0 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -1,8 +1,6 @@ from typing import Optional, Union -from typing import Optional, Union from ...common.database import Database -from .message_base import MessageBase from .message import MessageSending, MessageRecv from .chat_stream import ChatStream from loguru import logger diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 7c658fbf7..186f2ab79 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -12,8 +12,8 @@ from loguru import logger from ..models.utils_model import LLM_request from ..utils.typo_generator import ChineseTypoGenerator from .config import global_config -from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message -from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo +from .message import MessageRecv,Message +from .message_base import UserInfo from .chat_stream import ChatStream from ..moods.moods import MoodManager diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 25f23359b..42d5f9efc 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -1,16 +1,12 @@ import base64 -import io import os import time -import zlib import aiohttp import hashlib -from typing import Optional, Tuple, Union -from urllib.parse import urlparse +from typing import Optional, Union from loguru import logger from nonebot import get_driver -from PIL import Image from ...common.database import Database from ..chat.config import global_config diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index 39083f0b8..f34afb746 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -1,13 +1,9 @@ import asyncio from typing import Dict -from loguru import logger -from typing import Dict -from loguru import logger from .config import global_config -from .message_base import UserInfo, GroupInfo -from .chat_stream import chat_manager,ChatStream +from .chat_stream import ChatStream class WillingManager: diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py deleted file mode 100644 index e9d7167fd..000000000 --- a/src/plugins/knowledege/knowledge_library.py +++ /dev/null @@ -1,199 +0,0 @@ -import os -import sys -import time - -import requests -from dotenv import load_dotenv - -# 添加项目根目录到 Python 路径 -root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) -sys.path.append(root_path) - -# 加载根目录下的env.edv文件 -env_path = os.path.join(root_path, ".env.dev") -if not os.path.exists(env_path): - raise FileNotFoundError(f"配置文件不存在: {env_path}") -load_dotenv(env_path) - -from src.common.database import Database - -# 从环境变量获取配置 -Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), -) - -class KnowledgeLibrary: - def __init__(self): - self.db = Database.get_instance() - self.raw_info_dir = "data/raw_info" - self._ensure_dirs() - self.api_key = os.getenv("SILICONFLOW_KEY") - if not self.api_key: - raise ValueError("SILICONFLOW_API_KEY 环境变量未设置") - - def _ensure_dirs(self): - """确保必要的目录存在""" - os.makedirs(self.raw_info_dir, exist_ok=True) - - def get_embedding(self, text: str) -> list: - """获取文本的embedding向量""" - url = "https://api.siliconflow.cn/v1/embeddings" - payload = { - "model": "BAAI/bge-m3", - "input": text, - "encoding_format": "float" - } - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - response = requests.post(url, json=payload, headers=headers) - if response.status_code != 200: - print(f"获取embedding失败: {response.text}") - return None - - return response.json()['data'][0]['embedding'] - - def process_files(self): - """处理raw_info目录下的所有txt文件""" - for filename in os.listdir(self.raw_info_dir): - if filename.endswith('.txt'): - file_path = os.path.join(self.raw_info_dir, filename) - self.process_single_file(file_path) - - def process_single_file(self, file_path: str): - """处理单个文件""" - try: - # 检查文件是否已处理 - if self.db.db.processed_files.find_one({"file_path": file_path}): - print(f"文件已处理过,跳过: {file_path}") - return - - with open(file_path, 'r', encoding='utf-8') as f: - content = f.read() - - # 按1024字符分段 - segments = [content[i:i+600] for i in range(0, len(content), 300)] - - # 处理每个分段 - for segment in segments: - if not segment.strip(): # 跳过空段 - continue - - # 获取embedding - embedding = self.get_embedding(segment) - if not embedding: - continue - - # 存储到数据库 - doc = { - "content": segment, - "embedding": embedding, - "file_path": file_path, - "segment_length": len(segment) - } - - # 使用文本内容的哈希值作为唯一标识 - content_hash = hash(segment) - - # 更新或插入文档 - self.db.db.knowledges.update_one( - {"content_hash": content_hash}, - {"$set": doc}, - upsert=True - ) - - # 记录文件已处理 - self.db.db.processed_files.insert_one({ - "file_path": file_path, - "processed_time": time.time() - }) - - print(f"成功处理文件: {file_path}") - - except Exception as e: - print(f"处理文件 {file_path} 时出错: {str(e)}") - - def search_similar_segments(self, query: str, limit: int = 5) -> list: - """搜索与查询文本相似的片段""" - query_embedding = self.get_embedding(query) - if not query_embedding: - return [] - - # 使用余弦相似度计算 - pipeline = [ - { - "$addFields": { - "dotProduct": { - "$reduce": { - "input": {"$range": [0, {"$size": "$embedding"}]}, - "initialValue": 0, - "in": { - "$add": [ - "$$value", - {"$multiply": [ - {"$arrayElemAt": ["$embedding", "$$this"]}, - {"$arrayElemAt": [query_embedding, "$$this"]} - ]} - ] - } - } - }, - "magnitude1": { - "$sqrt": { - "$reduce": { - "input": "$embedding", - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} - } - } - }, - "magnitude2": { - "$sqrt": { - "$reduce": { - "input": query_embedding, - "initialValue": 0, - "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} - } - } - } - } - }, - { - "$addFields": { - "similarity": { - "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}] - } - } - }, - {"$sort": {"similarity": -1}}, - {"$limit": limit}, - {"$project": {"content": 1, "similarity": 1, "file_path": 1}} - ] - - results = list(self.db.db.knowledges.aggregate(pipeline)) - return results - -# 创建单例实例 -knowledge_library = KnowledgeLibrary() - -if __name__ == "__main__": - # 测试知识库功能 - print("开始处理知识库文件...") - knowledge_library.process_files() - - # 测试搜索功能 - test_query = "麦麦评价一下僕と花" - print(f"\n搜索与'{test_query}'相似的内容:") - results = knowledge_library.search_similar_segments(test_query) - for result in results: - print(f"相似度: {result['similarity']:.4f}") - print(f"内容: {result['content'][:100]}...") - print("-" * 50) diff --git a/src/plugins/memory_system/memory_manual_build.py b/src/plugins/memory_system/memory_manual_build.py index 9c1d43ce9..736a50e97 100644 --- a/src/plugins/memory_system/memory_manual_build.py +++ b/src/plugins/memory_system/memory_manual_build.py @@ -10,7 +10,6 @@ from pathlib import Path import matplotlib.pyplot as plt import networkx as nx -import pymongo from dotenv import load_dotenv from loguru import logger import jieba diff --git a/src/plugins/zhishi/knowledge_library.py b/src/plugins/zhishi/knowledge_library.py new file mode 100644 index 000000000..2411e3112 --- /dev/null +++ b/src/plugins/zhishi/knowledge_library.py @@ -0,0 +1,383 @@ +import os +import sys +import time +import requests +from dotenv import load_dotenv +import hashlib +from datetime import datetime +from tqdm import tqdm +from rich.console import Console +from rich.table import Table + +# 添加项目根目录到 Python 路径 +root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +sys.path.append(root_path) + +# 现在可以导入src模块 +from src.common.database import Database + +# 加载根目录下的env.edv文件 +env_path = os.path.join(root_path, ".env.prod") +if not os.path.exists(env_path): + raise FileNotFoundError(f"配置文件不存在: {env_path}") +load_dotenv(env_path) + +class KnowledgeLibrary: + def __init__(self): + # 初始化数据库连接 + if Database._instance is None: + Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), + ) + self.db = Database.get_instance() + self.raw_info_dir = "data/raw_info" + self._ensure_dirs() + self.api_key = os.getenv("SILICONFLOW_KEY") + if not self.api_key: + raise ValueError("SILICONFLOW_API_KEY 环境变量未设置") + self.console = Console() + + def _ensure_dirs(self): + """确保必要的目录存在""" + os.makedirs(self.raw_info_dir, exist_ok=True) + + def read_file(self, file_path: str) -> str: + """读取文件内容""" + with open(file_path, 'r', encoding='utf-8') as f: + return f.read() + + def split_content(self, content: str, max_length: int = 512) -> list: + """将内容分割成适当大小的块,保持段落完整性 + + Args: + content: 要分割的文本内容 + max_length: 每个块的最大长度 + + Returns: + list: 分割后的文本块列表 + """ + # 首先按段落分割 + paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()] + chunks = [] + current_chunk = [] + current_length = 0 + + for para in paragraphs: + para_length = len(para) + + # 如果单个段落就超过最大长度 + if para_length > max_length: + # 如果当前chunk不为空,先保存 + if current_chunk: + chunks.append('\n'.join(current_chunk)) + current_chunk = [] + current_length = 0 + + # 将长段落按句子分割 + sentences = [s.strip() for s in para.replace('。', '。\n').replace('!', '!\n').replace('?', '?\n').split('\n') if s.strip()] + temp_chunk = [] + temp_length = 0 + + for sentence in sentences: + sentence_length = len(sentence) + if sentence_length > max_length: + # 如果单个句子超长,强制按长度分割 + if temp_chunk: + chunks.append('\n'.join(temp_chunk)) + temp_chunk = [] + temp_length = 0 + for i in range(0, len(sentence), max_length): + chunks.append(sentence[i:i + max_length]) + elif temp_length + sentence_length + 1 <= max_length: + temp_chunk.append(sentence) + temp_length += sentence_length + 1 + else: + chunks.append('\n'.join(temp_chunk)) + temp_chunk = [sentence] + temp_length = sentence_length + + if temp_chunk: + chunks.append('\n'.join(temp_chunk)) + + # 如果当前段落加上现有chunk不超过最大长度 + elif current_length + para_length + 1 <= max_length: + current_chunk.append(para) + current_length += para_length + 1 + else: + # 保存当前chunk并开始新的chunk + chunks.append('\n'.join(current_chunk)) + current_chunk = [para] + current_length = para_length + + # 添加最后一个chunk + if current_chunk: + chunks.append('\n'.join(current_chunk)) + + return chunks + + def get_embedding(self, text: str) -> list: + """获取文本的embedding向量""" + url = "https://api.siliconflow.cn/v1/embeddings" + payload = { + "model": "BAAI/bge-m3", + "input": text, + "encoding_format": "float" + } + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + print(f"获取embedding失败: {response.text}") + return None + + return response.json()['data'][0]['embedding'] + + def process_files(self, knowledge_length:int=512): + """处理raw_info目录下的所有txt文件""" + txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')] + + if not txt_files: + self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir)) + self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]") + return + + total_stats = { + "processed_files": 0, + "total_chunks": 0, + "failed_files": [], + "skipped_files": [] + } + + self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]") + + for filename in tqdm(txt_files, desc="处理文件进度"): + file_path = os.path.join(self.raw_info_dir, filename) + result = self.process_single_file(file_path, knowledge_length) + self._update_stats(total_stats, result, filename) + + self._display_processing_results(total_stats) + + def process_single_file(self, file_path: str, knowledge_length: int = 512): + """处理单个文件""" + result = { + "status": "success", + "chunks_processed": 0, + "error": None + } + + try: + current_hash = self.calculate_file_hash(file_path) + processed_record = self.db.db.processed_files.find_one({"file_path": file_path}) + + if processed_record: + if processed_record.get("hash") == current_hash: + if knowledge_length in processed_record.get("split_by", []): + result["status"] = "skipped" + return result + + content = self.read_file(file_path) + chunks = self.split_content(content, knowledge_length) + + for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False): + embedding = self.get_embedding(chunk) + if embedding: + knowledge = { + "content": chunk, + "embedding": embedding, + "source_file": file_path, + "split_length": knowledge_length, + "created_at": datetime.now() + } + self.db.db.knowledges.insert_one(knowledge) + result["chunks_processed"] += 1 + + split_by = processed_record.get("split_by", []) if processed_record else [] + if knowledge_length not in split_by: + split_by.append(knowledge_length) + + self.db.db.processed_files.update_one( + {"file_path": file_path}, + { + "$set": { + "hash": current_hash, + "last_processed": datetime.now(), + "split_by": split_by + } + }, + upsert=True + ) + + except Exception as e: + result["status"] = "failed" + result["error"] = str(e) + + return result + + def _update_stats(self, total_stats, result, filename): + """更新总体统计信息""" + if result["status"] == "success": + total_stats["processed_files"] += 1 + total_stats["total_chunks"] += result["chunks_processed"] + elif result["status"] == "failed": + total_stats["failed_files"].append((filename, result["error"])) + elif result["status"] == "skipped": + total_stats["skipped_files"].append(filename) + + def _display_processing_results(self, stats): + """显示处理结果统计""" + self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]") + + table = Table(show_header=True, header_style="bold magenta") + table.add_column("统计项", style="dim") + table.add_column("数值") + + table.add_row("成功处理文件数", str(stats["processed_files"])) + table.add_row("处理的知识块总数", str(stats["total_chunks"])) + table.add_row("跳过的文件数", str(len(stats["skipped_files"]))) + table.add_row("失败的文件数", str(len(stats["failed_files"]))) + + self.console.print(table) + + if stats["failed_files"]: + self.console.print("\n[bold red]处理失败的文件:[/bold red]") + for filename, error in stats["failed_files"]: + self.console.print(f"[red]- {filename}: {error}[/red]") + + if stats["skipped_files"]: + self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]") + for filename in stats["skipped_files"]: + self.console.print(f"[yellow]- {filename}[/yellow]") + + def calculate_file_hash(self, file_path): + """计算文件的MD5哈希值""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + def search_similar_segments(self, query: str, limit: int = 5) -> list: + """搜索与查询文本相似的片段""" + query_embedding = self.get_embedding(query) + if not query_embedding: + return [] + + # 使用余弦相似度计算 + pipeline = [ + { + "$addFields": { + "dotProduct": { + "$reduce": { + "input": {"$range": [0, {"$size": "$embedding"}]}, + "initialValue": 0, + "in": { + "$add": [ + "$$value", + {"$multiply": [ + {"$arrayElemAt": ["$embedding", "$$this"]}, + {"$arrayElemAt": [query_embedding, "$$this"]} + ]} + ] + } + } + }, + "magnitude1": { + "$sqrt": { + "$reduce": { + "input": "$embedding", + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + } + } + }, + "magnitude2": { + "$sqrt": { + "$reduce": { + "input": query_embedding, + "initialValue": 0, + "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]} + } + } + } + } + }, + { + "$addFields": { + "similarity": { + "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}] + } + } + }, + {"$sort": {"similarity": -1}}, + {"$limit": limit}, + {"$project": {"content": 1, "similarity": 1, "file_path": 1}} + ] + + results = list(self.db.db.knowledges.aggregate(pipeline)) + return results + +# 创建单例实例 +knowledge_library = KnowledgeLibrary() + +if __name__ == "__main__": + console = Console() + console.print("[bold green]知识库处理工具[/bold green]") + + while True: + console.print("\n请选择要执行的操作:") + console.print("[1] 麦麦开始学习") + console.print("[2] 麦麦全部忘光光(仅知识)") + console.print("[q] 退出程序") + + choice = input("\n请输入选项: ").strip() + + if choice.lower() == 'q': + console.print("[yellow]程序退出[/yellow]") + sys.exit(0) + elif choice == '2': + confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower() + if confirm == 'y': + knowledge_library.db.db.knowledges.delete_many({}) + console.print("[green]已清空所有知识![/green]") + continue + elif choice == '1': + if not os.path.exists(knowledge_library.raw_info_dir): + console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]") + os.makedirs(knowledge_library.raw_info_dir, exist_ok=True) + + # 询问分割长度 + while True: + try: + length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip() + if length_input.lower() == 'q': + break + if not length_input: # 如果直接回车,使用默认值 + knowledge_length = 512 + break + knowledge_length = int(length_input) + if knowledge_length <= 0: + print("分割长度必须大于0,请重新输入") + continue + break + except ValueError: + print("请输入有效的数字") + continue + + if length_input.lower() == 'q': + continue + + # 测试知识库功能 + print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...") + knowledge_library.process_files(knowledge_length=knowledge_length) + else: + console.print("[red]无效的选项,请重新选择[/red]") + continue diff --git a/麦麦开始学习.bat b/麦麦开始学习.bat new file mode 100644 index 000000000..f7391150f --- /dev/null +++ b/麦麦开始学习.bat @@ -0,0 +1,45 @@ +@echo off +setlocal enabledelayedexpansion +chcp 65001 +cd /d %~dp0 + +echo ===================================== +echo 选择Python环境: +echo 1 - venv (推荐) +echo 2 - conda +echo ===================================== +choice /c 12 /n /m "输入数字(1或2): " + +if errorlevel 2 ( + echo ===================================== + set "CONDA_ENV=" + set /p CONDA_ENV="请输入要激活的 conda 环境名称: " + + :: 检查输入是否为空 + if "!CONDA_ENV!"=="" ( + echo 错误:环境名称不能为空 + pause + exit /b 1 + ) + + call conda activate !CONDA_ENV! + if errorlevel 1 ( + echo 激活 conda 环境失败 + pause + exit /b 1 + ) + + echo Conda 环境 "!CONDA_ENV!" 激活成功 + python src/plugins/zhishi/knowledge_library.py +) else ( + if exist "venv\Scripts\python.exe" ( + venv\Scripts\python src/plugins/zhishi/knowledge_library.py + ) else ( + echo ===================================== + echo 错误: venv环境不存在,请先创建虚拟环境 + pause + exit /b 1 + ) +) +endlocal +pause