Added a polished one-click startup script for the knowledge library

SengokuCola
2025-03-11 23:46:49 +08:00
parent 8d68e63835
commit ed18f2e96d
19 changed files with 454 additions and 258 deletions

.gitignore (vendored): 1 change

@@ -1,4 +1,5 @@
data/
+data1/
mongodb/
NapCat.Framework.Windows.Once/
log/

run.py: 12 changes

@@ -128,13 +128,17 @@ if __name__ == "__main__":
        )
        os.system("cls")
        if choice == "1":
-            install_napcat()
-            install_mongodb()
+            confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n")
+            if confirm == "1":
+                install_napcat()
+                install_mongodb()
+            else:
+                print("已取消安装")
        elif choice == "2":
            run_maimbot()
-            choice = input("是否启动推理可视化?(y/N)").upper()
+            choice = input("是否启动推理可视化?(未完善)(y/N)").upper()
            if choice == "Y":
                run_cmd(r"python src\gui\reasoning_gui.py")
-            choice = input("是否启动记忆可视化?(y/N)").upper()
+            choice = input("是否启动记忆可视化?(未完善)(y/N)").upper()
            if choice == "Y":
                run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")


@@ -1,8 +1,6 @@
from typing import Optional
from pymongo import MongoClient

class Database:
    _instance: Optional["Database"] = None
@@ -50,25 +48,4 @@ class Database:
    def get_instance(cls) -> "Database":
        if cls._instance is None:
            raise RuntimeError("Database not initialized")
        return cls._instance
-
-    # for testing only
-    def get_random_group_messages(self, group_id: str, limit: int = 5):
-        # grab one random message first
-        random_message = list(self.db.messages.aggregate([
-            {"$match": {"group_id": group_id}},
-            {"$sample": {"size": 1}}
-        ]))[0]
-        # fetch the messages that come after it
-        subsequent_messages = list(self.db.messages.find({
-            "group_id": group_id,
-            "time": {"$gt": random_message["time"]}
-        }).sort("time", 1).limit(limit))
-        # merge the random message with its successors
-        messages = [random_message] + subsequent_messages
-        return messages


@@ -7,7 +7,6 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager  # import the mood manager
from .config import global_config
-from .cq_code import CQCode, cq_code_tool  # import the CQCode module
from .emoji_manager import emoji_manager  # import the sticker manager
from .llm_generator import ResponseGenerator
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
@@ -218,7 +217,7 @@ class ChatBot:
        # the message_set can be handed to message_manager directly
        # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
-        print(f"添加message_set到message_manager")
+        print("添加message_set到message_manager")
        message_manager.add_message(message_set)
        bot_response_time = thinking_time_point


@@ -8,7 +8,7 @@ from loguru import logger
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
-from .message import MessageRecv, MessageThinking, MessageSending, Message
+from .message import MessageRecv, MessageThinking, Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response


@@ -3,14 +3,14 @@ import html
import re
import json
from dataclasses import dataclass
-from typing import Dict, ForwardRef, List, Optional, Union
+from typing import Dict, List, Optional
import urllib3
from loguru import logger
from .utils_image import image_manager
-from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
+from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
-from .chat_stream import ChatStream, chat_manager
+from .chat_stream import ChatStream

# disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


@@ -1,5 +1,5 @@
from dataclasses import dataclass, asdict
-from typing import List, Optional, Union, Any, Dict
+from typing import List, Optional, Union, Dict

@dataclass
class Seg:


@@ -1,12 +1,12 @@
import time
from dataclasses import dataclass
-from typing import Dict, ForwardRef, List, Optional, Union
+from typing import Dict, Optional
import urllib3
-from .cq_code import CQCode, cq_code_tool
+from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code
-from .utils_user import get_groupname, get_user_cardname, get_user_nickname
+from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase

# disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


@@ -5,12 +5,10 @@ from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot
-from .cq_code import cq_code_tool
from .message_cq import MessageSendCQ
-from .message import MessageSending, MessageThinking, MessageRecv, MessageSet
+from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage
from .config import global_config
-from .chat_stream import chat_manager

class Message_Sender:


@@ -9,7 +9,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule
from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text
-from .chat_stream import ChatStream, chat_manager
+from .chat_stream import chat_manager

class PromptBuilder:


@@ -1,6 +1,5 @@
import asyncio
-from typing import Optional, Union
-from typing import Optional, Union
+from typing import Optional
from loguru import logger
from ...common.database import Database


@@ -1,8 +1,6 @@
from typing import Optional, Union
-from typing import Optional, Union
from ...common.database import Database
-from .message_base import MessageBase
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger


@@ -12,8 +12,8 @@ from loguru import logger
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
-from .message import MessageThinking, MessageRecv, MessageSending, MessageProcessBase, Message
+from .message import MessageRecv, Message
-from .message_base import MessageBase, BaseMessageInfo, UserInfo, GroupInfo
+from .message_base import UserInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager


@@ -1,16 +1,12 @@
import base64
-import io
import os
import time
-import zlib
import aiohttp
import hashlib
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
-from urllib.parse import urlparse
from loguru import logger
from nonebot import get_driver
-from PIL import Image
from ...common.database import Database
from ..chat.config import global_config


@@ -1,13 +1,9 @@
import asyncio
from typing import Dict
from loguru import logger
-from typing import Dict
-from loguru import logger
from .config import global_config
-from .message_base import UserInfo, GroupInfo
-from .chat_stream import chat_manager, ChatStream
+from .chat_stream import ChatStream

class WillingManager:


@@ -1,199 +0,0 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv

# add the project root to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

# load the .env.dev file from the project root
env_path = os.path.join(root_path, ".env.dev")
if not os.path.exists(env_path):
    raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)

from src.common.database import Database

# read the connection settings from environment variables
Database.initialize(
    uri=os.getenv("MONGODB_URI"),
    host=os.getenv("MONGODB_HOST", "127.0.0.1"),
    port=int(os.getenv("MONGODB_PORT", "27017")),
    db_name=os.getenv("DATABASE_NAME", "MegBot"),
    username=os.getenv("MONGODB_USERNAME"),
    password=os.getenv("MONGODB_PASSWORD"),
    auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class KnowledgeLibrary:
    def __init__(self):
        self.db = Database.get_instance()
        self.raw_info_dir = "data/raw_info"
        self._ensure_dirs()
        self.api_key = os.getenv("SILICONFLOW_KEY")
        if not self.api_key:
            raise ValueError("SILICONFLOW_KEY 环境变量未设置")

    def _ensure_dirs(self):
        """Make sure the required directories exist."""
        os.makedirs(self.raw_info_dir, exist_ok=True)

    def get_embedding(self, text: str) -> list:
        """Fetch the embedding vector for a piece of text."""
        url = "https://api.siliconflow.cn/v1/embeddings"
        payload = {
            "model": "BAAI/bge-m3",
            "input": text,
            "encoding_format": "float"
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            print(f"获取embedding失败: {response.text}")
            return None
        return response.json()['data'][0]['embedding']

    def process_files(self):
        """Process every .txt file under the raw_info directory."""
        for filename in os.listdir(self.raw_info_dir):
            if filename.endswith('.txt'):
                file_path = os.path.join(self.raw_info_dir, filename)
                self.process_single_file(file_path)
    def process_single_file(self, file_path: str):
        """Process a single file."""
        try:
            # skip files that have already been processed
            if self.db.db.processed_files.find_one({"file_path": file_path}):
                print(f"文件已处理过,跳过: {file_path}")
                return
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            # split into 600-character segments with a 300-character stride, so neighbours overlap
            segments = [content[i:i+600] for i in range(0, len(content), 300)]
            # process each segment
            for segment in segments:
                if not segment.strip():  # skip empty segments
                    continue
                # fetch the embedding
                embedding = self.get_embedding(segment)
                if not embedding:
                    continue
                # store it in the database
                doc = {
                    "content": segment,
                    "embedding": embedding,
                    "file_path": file_path,
                    "segment_length": len(segment)
                }
                # use a hash of the segment text as the unique key
                content_hash = hash(segment)
                # update or insert the document
                self.db.db.knowledges.update_one(
                    {"content_hash": content_hash},
                    {"$set": doc},
                    upsert=True
                )
            # record that the file has been processed
            self.db.db.processed_files.insert_one({
                "file_path": file_path,
                "processed_time": time.time()
            })
            print(f"成功处理文件: {file_path}")
        except Exception as e:
            print(f"处理文件 {file_path} 时出错: {str(e)}")
    def search_similar_segments(self, query: str, limit: int = 5) -> list:
        """Search for segments similar to the query text."""
        query_embedding = self.get_embedding(query)
        if not query_embedding:
            return []
        # compute cosine similarity inside the aggregation pipeline
        pipeline = [
            {
                "$addFields": {
                    "dotProduct": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {"$multiply": [
                                        {"$arrayElemAt": ["$embedding", "$$this"]},
                                        {"$arrayElemAt": [query_embedding, "$$this"]}
                                    ]}
                                ]
                            }
                        }
                    },
                    "magnitude1": {
                        "$sqrt": {
                            "$reduce": {
                                "input": "$embedding",
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    },
                    "magnitude2": {
                        "$sqrt": {
                            "$reduce": {
                                "input": query_embedding,
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    }
                }
            },
            {
                "$addFields": {
                    "similarity": {
                        "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
                    }
                }
            },
            {"$sort": {"similarity": -1}},
            {"$limit": limit},
            {"$project": {"content": 1, "similarity": 1, "file_path": 1}}
        ]
        results = list(self.db.db.knowledges.aggregate(pipeline))
        return results
# create the singleton instance
knowledge_library = KnowledgeLibrary()

if __name__ == "__main__":
    # exercise the knowledge library
    print("开始处理知识库文件...")
    knowledge_library.process_files()

    # exercise the search
    test_query = "麦麦评价一下僕と花"
    print(f"\n搜索与'{test_query}'相似的内容:")
    results = knowledge_library.search_similar_segments(test_query)
    for result in results:
        print(f"相似度: {result['similarity']:.4f}")
        print(f"内容: {result['content'][:100]}...")
        print("-" * 50)


@@ -10,7 +10,6 @@ from pathlib import Path
import matplotlib.pyplot as plt
import networkx as nx
-import pymongo
from dotenv import load_dotenv
from loguru import logger
import jieba


@@ -0,0 +1,383 @@
import os
import sys
import time
import requests
from dotenv import load_dotenv
import hashlib
from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table

# add the project root to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
sys.path.append(root_path)

# src modules are importable from this point on
from src.common.database import Database

# load the .env.prod file from the project root
env_path = os.path.join(root_path, ".env.prod")
if not os.path.exists(env_path):
    raise FileNotFoundError(f"配置文件不存在: {env_path}")
load_dotenv(env_path)

class KnowledgeLibrary:
    def __init__(self):
        # initialize the database connection
        if Database._instance is None:
            Database.initialize(
                uri=os.getenv("MONGODB_URI"),
                host=os.getenv("MONGODB_HOST", "127.0.0.1"),
                port=int(os.getenv("MONGODB_PORT", "27017")),
                db_name=os.getenv("DATABASE_NAME", "MegBot"),
                username=os.getenv("MONGODB_USERNAME"),
                password=os.getenv("MONGODB_PASSWORD"),
                auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
            )
        self.db = Database.get_instance()
        self.raw_info_dir = "data/raw_info"
        self._ensure_dirs()
        self.api_key = os.getenv("SILICONFLOW_KEY")
        if not self.api_key:
            raise ValueError("SILICONFLOW_KEY 环境变量未设置")
        self.console = Console()

    def _ensure_dirs(self):
        """Make sure the required directories exist."""
        os.makedirs(self.raw_info_dir, exist_ok=True)

    def read_file(self, file_path: str) -> str:
        """Read a file's content."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    def split_content(self, content: str, max_length: int = 512) -> list:
        """Split the content into appropriately sized chunks while keeping paragraphs intact.

        Args:
            content: the text to split
            max_length: maximum length of each chunk

        Returns:
            list: the resulting text chunks
        """
        # split into paragraphs first
        paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
        chunks = []
        current_chunk = []
        current_length = 0
        for para in paragraphs:
            para_length = len(para)
            # the paragraph alone exceeds the maximum length
            if para_length > max_length:
                # flush the current chunk first if it is not empty
                if current_chunk:
                    chunks.append('\n'.join(current_chunk))
                    current_chunk = []
                    current_length = 0
                # split the long paragraph on sentence-ending punctuation
                sentences = [s.strip() for s in para.replace('。', '\n').replace('!', '\n').replace('?', '\n').split('\n') if s.strip()]
                temp_chunk = []
                temp_length = 0
                for sentence in sentences:
                    sentence_length = len(sentence)
                    if sentence_length > max_length:
                        # a single sentence is still too long: hard-split by length
                        if temp_chunk:
                            chunks.append('\n'.join(temp_chunk))
                            temp_chunk = []
                            temp_length = 0
                        for i in range(0, len(sentence), max_length):
                            chunks.append(sentence[i:i + max_length])
                    elif temp_length + sentence_length + 1 <= max_length:
                        temp_chunk.append(sentence)
                        temp_length += sentence_length + 1
                    else:
                        chunks.append('\n'.join(temp_chunk))
                        temp_chunk = [sentence]
                        temp_length = sentence_length
                if temp_chunk:
                    chunks.append('\n'.join(temp_chunk))
            # the paragraph still fits into the current chunk
            elif current_length + para_length + 1 <= max_length:
                current_chunk.append(para)
                current_length += para_length + 1
            else:
                # save the current chunk and start a new one
                chunks.append('\n'.join(current_chunk))
                current_chunk = [para]
                current_length = para_length
        # append the last chunk
        if current_chunk:
            chunks.append('\n'.join(current_chunk))
        return chunks
    def get_embedding(self, text: str) -> list:
        """Fetch the embedding vector for a piece of text."""
        url = "https://api.siliconflow.cn/v1/embeddings"
        payload = {
            "model": "BAAI/bge-m3",
            "input": text,
            "encoding_format": "float"
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        response = requests.post(url, json=payload, headers=headers)
        if response.status_code != 200:
            print(f"获取embedding失败: {response.text}")
            return None
        return response.json()['data'][0]['embedding']

    def process_files(self, knowledge_length: int = 512):
        """Process every .txt file under the raw_info directory."""
        txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
        if not txt_files:
            self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
            self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
            return
        total_stats = {
            "processed_files": 0,
            "total_chunks": 0,
            "failed_files": [],
            "skipped_files": []
        }
        self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
        for filename in tqdm(txt_files, desc="处理文件进度"):
            file_path = os.path.join(self.raw_info_dir, filename)
            result = self.process_single_file(file_path, knowledge_length)
            self._update_stats(total_stats, result, filename)
        self._display_processing_results(total_stats)
    def process_single_file(self, file_path: str, knowledge_length: int = 512):
        """Process a single file."""
        result = {
            "status": "success",
            "chunks_processed": 0,
            "error": None
        }
        try:
            # skip unchanged files that were already processed at this split length
            current_hash = self.calculate_file_hash(file_path)
            processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
            if processed_record:
                if processed_record.get("hash") == current_hash:
                    if knowledge_length in processed_record.get("split_by", []):
                        result["status"] = "skipped"
                        return result
            content = self.read_file(file_path)
            chunks = self.split_content(content, knowledge_length)
            for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
                embedding = self.get_embedding(chunk)
                if embedding:
                    knowledge = {
                        "content": chunk,
                        "embedding": embedding,
                        "source_file": file_path,
                        "split_length": knowledge_length,
                        "created_at": datetime.now()
                    }
                    self.db.db.knowledges.insert_one(knowledge)
                    result["chunks_processed"] += 1
            # remember which split lengths this file has been processed with
            split_by = processed_record.get("split_by", []) if processed_record else []
            if knowledge_length not in split_by:
                split_by.append(knowledge_length)
            self.db.db.processed_files.update_one(
                {"file_path": file_path},
                {
                    "$set": {
                        "hash": current_hash,
                        "last_processed": datetime.now(),
                        "split_by": split_by
                    }
                },
                upsert=True
            )
        except Exception as e:
            result["status"] = "failed"
            result["error"] = str(e)
        return result
    def _update_stats(self, total_stats, result, filename):
        """Update the aggregate statistics."""
        if result["status"] == "success":
            total_stats["processed_files"] += 1
            total_stats["total_chunks"] += result["chunks_processed"]
        elif result["status"] == "failed":
            total_stats["failed_files"].append((filename, result["error"]))
        elif result["status"] == "skipped":
            total_stats["skipped_files"].append(filename)

    def _display_processing_results(self, stats):
        """Display the processing statistics."""
        self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("统计项", style="dim")
        table.add_column("数值")
        table.add_row("成功处理文件数", str(stats["processed_files"]))
        table.add_row("处理的知识块总数", str(stats["total_chunks"]))
        table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
        table.add_row("失败的文件数", str(len(stats["failed_files"])))
        self.console.print(table)
        if stats["failed_files"]:
            self.console.print("\n[bold red]处理失败的文件:[/bold red]")
            for filename, error in stats["failed_files"]:
                self.console.print(f"[red]- {filename}: {error}[/red]")
        if stats["skipped_files"]:
            self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
            for filename in stats["skipped_files"]:
                self.console.print(f"[yellow]- {filename}[/yellow]")

    def calculate_file_hash(self, file_path):
        """Compute the file's MD5 hash, reading in 4 KB blocks."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
    def search_similar_segments(self, query: str, limit: int = 5) -> list:
        """Search for segments similar to the query text."""
        query_embedding = self.get_embedding(query)
        if not query_embedding:
            return []
        # compute cosine similarity inside the aggregation pipeline
        pipeline = [
            {
                "$addFields": {
                    "dotProduct": {
                        "$reduce": {
                            "input": {"$range": [0, {"$size": "$embedding"}]},
                            "initialValue": 0,
                            "in": {
                                "$add": [
                                    "$$value",
                                    {"$multiply": [
                                        {"$arrayElemAt": ["$embedding", "$$this"]},
                                        {"$arrayElemAt": [query_embedding, "$$this"]}
                                    ]}
                                ]
                            }
                        }
                    },
                    "magnitude1": {
                        "$sqrt": {
                            "$reduce": {
                                "input": "$embedding",
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    },
                    "magnitude2": {
                        "$sqrt": {
                            "$reduce": {
                                "input": query_embedding,
                                "initialValue": 0,
                                "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
                            }
                        }
                    }
                }
            },
            {
                "$addFields": {
                    "similarity": {
                        "$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
                    }
                }
            },
            {"$sort": {"similarity": -1}},
            {"$limit": limit},
            # documents are stored with "source_file" (see process_single_file)
            {"$project": {"content": 1, "similarity": 1, "source_file": 1}}
        ]
        results = list(self.db.db.knowledges.aggregate(pipeline))
        return results
# create the singleton instance
knowledge_library = KnowledgeLibrary()

if __name__ == "__main__":
    console = Console()
    console.print("[bold green]知识库处理工具[/bold green]")
    while True:
        console.print("\n请选择要执行的操作:")
        console.print("[1] 麦麦开始学习")
        console.print("[2] 麦麦全部忘光光(仅知识)")
        console.print("[q] 退出程序")
        choice = input("\n请输入选项: ").strip()
        if choice.lower() == 'q':
            console.print("[yellow]程序退出[/yellow]")
            sys.exit(0)
        elif choice == '2':
            confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
            if confirm == 'y':
                knowledge_library.db.db.knowledges.delete_many({})
                console.print("[green]已清空所有知识![/green]")
            continue
        elif choice == '1':
            if not os.path.exists(knowledge_library.raw_info_dir):
                console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
                os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
            # ask for the split length
            while True:
                try:
                    length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip()
                    if length_input.lower() == 'q':
                        break
                    if not length_input:  # plain Enter uses the default
                        knowledge_length = 512
                        break
                    knowledge_length = int(length_input)
                    if knowledge_length <= 0:
                        print("分割长度必须大于0,请重新输入")
                        continue
                    break
                except ValueError:
                    print("请输入有效的数字")
                    continue
            if length_input.lower() == 'q':
                continue
            # process the knowledge-library files
            print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
            knowledge_library.process_files(knowledge_length=knowledge_length)
        else:
            console.print("[red]无效的选项,请重新选择[/red]")
            continue
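The aggregation pipeline in search_similar_segments computes cosine similarity between the query embedding and each stored chunk entirely inside MongoDB: the three $reduce expressions accumulate the dot product and the two vector magnitudes, and the final $divide combines them. In LaTeX form:

    \text{similarity}(q, d) = \frac{q \cdot d}{\lVert q \rVert \, \lVert d \rVert}
        = \frac{\sum_{i=1}^{n} q_i d_i}{\sqrt{\sum_{i=1}^{n} q_i^2}\,\sqrt{\sum_{i=1}^{n} d_i^2}}

Sorting on this value in descending order and applying $limit then returns the most similar chunks; a value of 1 means the two vectors point in exactly the same direction.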

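To preview how split_content will carve up a document before processing a whole directory, here is a minimal usage sketch (the sample text is hypothetical, and the constructor assumes the MongoDB settings and SILICONFLOW_KEY from .env.prod are configured):

    from src.plugins.zhishi.knowledge_library import KnowledgeLibrary

    kl = KnowledgeLibrary()
    # one short paragraph plus one paragraph far beyond max_length
    text = "短段落。\n\n" + "这是一个很长的段落。" * 100
    chunks = kl.split_content(text, max_length=512)
    # paragraphs that fit stay whole; oversized paragraphs are split on 。/!/?,
    # and any single sentence longer than max_length is hard-cut
    assert all(len(c) <= 512 for c in chunks)
    print(len(chunks), [len(c) for c in chunks[:3]])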
麦麦开始学习.bat (new file): 45 additions

@@ -0,0 +1,45 @@
@echo off
setlocal enabledelayedexpansion
chcp 65001
cd /d %~dp0

echo =====================================
echo 选择Python环境:
echo 1 - venv (推荐)
echo 2 - conda
echo =====================================

choice /c 12 /n /m "输入数字(1或2): "
if errorlevel 2 (
    echo =====================================
    set "CONDA_ENV="
    set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
    rem make sure the input is not empty
    if "!CONDA_ENV!"=="" (
        echo 错误:环境名称不能为空
        pause
        exit /b 1
    )
    call conda activate !CONDA_ENV!
    if errorlevel 1 (
        echo 激活 conda 环境失败
        pause
        exit /b 1
    )
    echo Conda 环境 "!CONDA_ENV!" 激活成功
    python src/plugins/zhishi/knowledge_library.py
) else (
    if exist "venv\Scripts\python.exe" (
        venv\Scripts\python src/plugins/zhishi/knowledge_library.py
    ) else (
        echo =====================================
        echo 错误: venv环境不存在,请先创建虚拟环境
        pause
        exit /b 1
    )
)

endlocal
pause
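The batch file above is Windows-only. As a rough cross-platform stand-in, a Python launcher could make the same venv-or-fallback choice (a sketch under the same layout assumptions: a venv/ directory next to the script and the tool at src/plugins/zhishi/knowledge_library.py; the conda branch is omitted):

    import os
    import subprocess
    import sys

    root = os.path.dirname(os.path.abspath(__file__))
    # venv layout differs between Windows (Scripts) and POSIX (bin)
    venv_python = os.path.join(root, "venv", "Scripts" if os.name == "nt" else "bin", "python")
    target = os.path.join(root, "src", "plugins", "zhishi", "knowledge_library.py")

    # prefer the project venv interpreter, otherwise fall back to the current one
    interpreter = venv_python if os.path.exists(venv_python) else sys.executable
    sys.exit(subprocess.call([interpreter, target], cwd=root))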