新增了知识库一键启动漂亮脚本
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
data/
|
||||
data1/
|
||||
mongodb/
|
||||
NapCat.Framework.Windows.Once/
|
||||
log/
|
||||
|
||||
12
run.py
12
run.py
@@ -128,13 +128,17 @@ if __name__ == "__main__":
|
||||
)
|
||||
os.system("cls")
|
||||
if choice == "1":
|
||||
install_napcat()
|
||||
install_mongodb()
|
||||
confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n")
|
||||
if confirm == "1":
|
||||
install_napcat()
|
||||
install_mongodb()
|
||||
else:
|
||||
print("已取消安装")
|
||||
elif choice == "2":
|
||||
run_maimbot()
|
||||
choice = input("是否启动推理可视化?(y/N)").upper()
|
||||
choice = input("是否启动推理可视化?(未完善)(y/N)").upper()
|
||||
if choice == "Y":
|
||||
run_cmd(r"python src\gui\reasoning_gui.py")
|
||||
choice = input("是否启动记忆可视化?(y/N)").upper()
|
||||
choice = input("是否启动记忆可视化?(未完善)(y/N)").upper()
|
||||
if choice == "Y":
|
||||
run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from pymongo import MongoClient
|
||||
|
||||
|
||||
class Database:
|
||||
_instance: Optional["Database"] = None
|
||||
|
||||
@@ -50,25 +48,4 @@ class Database:
|
||||
def get_instance(cls) -> "Database":
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Database not initialized")
|
||||
return cls._instance
|
||||
|
||||
|
||||
#测试用
|
||||
|
||||
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
||||
# 先随机获取一条消息
|
||||
random_message = list(self.db.messages.aggregate([
|
||||
{"$match": {"group_id": group_id}},
|
||||
{"$sample": {"size": 1}}
|
||||
]))[0]
|
||||
|
||||
# 获取该消息之后的消息
|
||||
subsequent_messages = list(self.db.messages.find({
|
||||
"group_id": group_id,
|
||||
"time": {"$gt": random_message["time"]}
|
||||
}).sort("time", 1).limit(limit))
|
||||
|
||||
# 将随机消息和后续消息合并
|
||||
messages = [random_message] + subsequent_messages
|
||||
|
||||
return messages
|
||||
return cls._instance
|
||||
@@ -7,7 +7,6 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
|
||||
from ..memory_system.memory import hippocampus
|
||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
||||
from .config import global_config
|
||||
from .cq_code import CQCode, cq_code_tool # 导入CQCode模块
|
||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||
from .llm_generator import ResponseGenerator
|
||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
@@ -218,7 +217,7 @@ class ChatBot:
|
||||
|
||||
# message_set 可以直接加入 message_manager
|
||||
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
|
||||
print(f"添加message_set到message_manager")
|
||||
print("添加message_set到message_manager")
|
||||
message_manager.add_message(message_set)
|
||||
|
||||
bot_response_time = thinking_time_point
|
||||
|
||||
@@ -8,7 +8,7 @@ from loguru import logger
|
||||
from ...common.database import Database
|
||||
from ..models.utils_model import LLM_request
|
||||
from .config import global_config
|
||||
from .message import MessageRecv, MessageThinking, MessageSending,Message
|
||||
from .message import MessageRecv, MessageThinking, Message
|
||||
from .prompt_builder import prompt_builder
|
||||
from .relationship_manager import relationship_manager
|
||||
from .utils import process_llm_response
|
||||
|
||||
@@ -3,14 +3,14 @@ import html
|
||||
import re
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, ForwardRef, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
|
||||
from .utils_image import image_manager
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
from .chat_stream import ChatStream, chat_manager
|
||||
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from .chat_stream import ChatStream
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Union, Any, Dict
|
||||
from typing import List, Optional, Union, Dict
|
||||
|
||||
@dataclass
|
||||
class Seg:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, ForwardRef, List, Optional, Union
|
||||
from typing import Dict, Optional
|
||||
|
||||
import urllib3
|
||||
|
||||
from .cq_code import CQCode, cq_code_tool
|
||||
from .cq_code import cq_code_tool
|
||||
from .utils_cq import parse_cq_code
|
||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
||||
from .utils_user import get_groupname
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
@@ -5,12 +5,10 @@ from typing import Dict, List, Optional, Union
|
||||
from loguru import logger
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
|
||||
from .cq_code import cq_code_tool
|
||||
from .message_cq import MessageSendCQ
|
||||
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
|
||||
from .message import MessageSending, MessageThinking, MessageSet
|
||||
from .storage import MessageStorage
|
||||
from .config import global_config
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
|
||||
class Message_Sender:
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..moods.moods import MoodManager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .config import global_config
|
||||
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
||||
from .chat_stream import ChatStream, chat_manager
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
|
||||
from ...common.database import Database
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import MessageBase
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .chat_stream import ChatStream
|
||||
from loguru import logger
|
||||
|
||||
@@ -12,8 +12,8 @@ from loguru import logger
|
||||
from ..models.utils_model import LLM_request
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from .config import global_config
|
||||
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
|
||||
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
|
||||
from .message import MessageRecv,Message
|
||||
from .message_base import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
from ..moods.moods import MoodManager
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import zlib
|
||||
import aiohttp
|
||||
import hashlib
|
||||
from typing import Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
from PIL import Image
|
||||
|
||||
from ...common.database import Database
|
||||
from ..chat.config import global_config
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from .config import global_config
|
||||
from .message_base import UserInfo, GroupInfo
|
||||
from .chat_stream import chat_manager,ChatStream
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
|
||||
class WillingManager:
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
# 加载根目录下的env.edv文件
|
||||
env_path = os.path.join(root_path, ".env.dev")
|
||||
if not os.path.exists(env_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||
load_dotenv(env_path)
|
||||
|
||||
from src.common.database import Database
|
||||
|
||||
# 从环境变量获取配置
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
|
||||
class KnowledgeLibrary:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
self.raw_info_dir = "data/raw_info"
|
||||
self._ensure_dirs()
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||
|
||||
def _ensure_dirs(self):
|
||||
"""确保必要的目录存在"""
|
||||
os.makedirs(self.raw_info_dir, exist_ok=True)
|
||||
|
||||
def get_embedding(self, text: str) -> list:
|
||||
"""获取文本的embedding向量"""
|
||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"获取embedding失败: {response.text}")
|
||||
return None
|
||||
|
||||
return response.json()['data'][0]['embedding']
|
||||
|
||||
def process_files(self):
|
||||
"""处理raw_info目录下的所有txt文件"""
|
||||
for filename in os.listdir(self.raw_info_dir):
|
||||
if filename.endswith('.txt'):
|
||||
file_path = os.path.join(self.raw_info_dir, filename)
|
||||
self.process_single_file(file_path)
|
||||
|
||||
def process_single_file(self, file_path: str):
|
||||
"""处理单个文件"""
|
||||
try:
|
||||
# 检查文件是否已处理
|
||||
if self.db.db.processed_files.find_one({"file_path": file_path}):
|
||||
print(f"文件已处理过,跳过: {file_path}")
|
||||
return
|
||||
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# 按1024字符分段
|
||||
segments = [content[i:i+600] for i in range(0, len(content), 300)]
|
||||
|
||||
# 处理每个分段
|
||||
for segment in segments:
|
||||
if not segment.strip(): # 跳过空段
|
||||
continue
|
||||
|
||||
# 获取embedding
|
||||
embedding = self.get_embedding(segment)
|
||||
if not embedding:
|
||||
continue
|
||||
|
||||
# 存储到数据库
|
||||
doc = {
|
||||
"content": segment,
|
||||
"embedding": embedding,
|
||||
"file_path": file_path,
|
||||
"segment_length": len(segment)
|
||||
}
|
||||
|
||||
# 使用文本内容的哈希值作为唯一标识
|
||||
content_hash = hash(segment)
|
||||
|
||||
# 更新或插入文档
|
||||
self.db.db.knowledges.update_one(
|
||||
{"content_hash": content_hash},
|
||||
{"$set": doc},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
# 记录文件已处理
|
||||
self.db.db.processed_files.insert_one({
|
||||
"file_path": file_path,
|
||||
"processed_time": time.time()
|
||||
})
|
||||
|
||||
print(f"成功处理文件: {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理文件 {file_path} 时出错: {str(e)}")
|
||||
|
||||
def search_similar_segments(self, query: str, limit: int = 5) -> list:
|
||||
"""搜索与查询文本相似的片段"""
|
||||
query_embedding = self.get_embedding(query)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]}
|
||||
]}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"similarity": {
|
||||
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
||||
]
|
||||
|
||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||
return results
|
||||
|
||||
# 创建单例实例
|
||||
knowledge_library = KnowledgeLibrary()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试知识库功能
|
||||
print("开始处理知识库文件...")
|
||||
knowledge_library.process_files()
|
||||
|
||||
# 测试搜索功能
|
||||
test_query = "麦麦评价一下僕と花"
|
||||
print(f"\n搜索与'{test_query}'相似的内容:")
|
||||
results = knowledge_library.search_similar_segments(test_query)
|
||||
for result in results:
|
||||
print(f"相似度: {result['similarity']:.4f}")
|
||||
print(f"内容: {result['content'][:100]}...")
|
||||
print("-" * 50)
|
||||
@@ -10,7 +10,6 @@ from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import pymongo
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
import jieba
|
||||
|
||||
383
src/plugins/zhishi/knowledge_library.py
Normal file
383
src/plugins/zhishi/knowledge_library.py
Normal file
@@ -0,0 +1,383 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
# 现在可以导入src模块
|
||||
from src.common.database import Database
|
||||
|
||||
# 加载根目录下的env.edv文件
|
||||
env_path = os.path.join(root_path, ".env.prod")
|
||||
if not os.path.exists(env_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||
load_dotenv(env_path)
|
||||
|
||||
class KnowledgeLibrary:
|
||||
def __init__(self):
|
||||
# 初始化数据库连接
|
||||
if Database._instance is None:
|
||||
Database.initialize(
|
||||
uri=os.getenv("MONGODB_URI"),
|
||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||
)
|
||||
self.db = Database.get_instance()
|
||||
self.raw_info_dir = "data/raw_info"
|
||||
self._ensure_dirs()
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||
self.console = Console()
|
||||
|
||||
def _ensure_dirs(self):
|
||||
"""确保必要的目录存在"""
|
||||
os.makedirs(self.raw_info_dir, exist_ok=True)
|
||||
|
||||
def read_file(self, file_path: str) -> str:
|
||||
"""读取文件内容"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def split_content(self, content: str, max_length: int = 512) -> list:
|
||||
"""将内容分割成适当大小的块,保持段落完整性
|
||||
|
||||
Args:
|
||||
content: 要分割的文本内容
|
||||
max_length: 每个块的最大长度
|
||||
|
||||
Returns:
|
||||
list: 分割后的文本块列表
|
||||
"""
|
||||
# 首先按段落分割
|
||||
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for para in paragraphs:
|
||||
para_length = len(para)
|
||||
|
||||
# 如果单个段落就超过最大长度
|
||||
if para_length > max_length:
|
||||
# 如果当前chunk不为空,先保存
|
||||
if current_chunk:
|
||||
chunks.append('\n'.join(current_chunk))
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
# 将长段落按句子分割
|
||||
sentences = [s.strip() for s in para.replace('。', '。\n').replace('!', '!\n').replace('?', '?\n').split('\n') if s.strip()]
|
||||
temp_chunk = []
|
||||
temp_length = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_length = len(sentence)
|
||||
if sentence_length > max_length:
|
||||
# 如果单个句子超长,强制按长度分割
|
||||
if temp_chunk:
|
||||
chunks.append('\n'.join(temp_chunk))
|
||||
temp_chunk = []
|
||||
temp_length = 0
|
||||
for i in range(0, len(sentence), max_length):
|
||||
chunks.append(sentence[i:i + max_length])
|
||||
elif temp_length + sentence_length + 1 <= max_length:
|
||||
temp_chunk.append(sentence)
|
||||
temp_length += sentence_length + 1
|
||||
else:
|
||||
chunks.append('\n'.join(temp_chunk))
|
||||
temp_chunk = [sentence]
|
||||
temp_length = sentence_length
|
||||
|
||||
if temp_chunk:
|
||||
chunks.append('\n'.join(temp_chunk))
|
||||
|
||||
# 如果当前段落加上现有chunk不超过最大长度
|
||||
elif current_length + para_length + 1 <= max_length:
|
||||
current_chunk.append(para)
|
||||
current_length += para_length + 1
|
||||
else:
|
||||
# 保存当前chunk并开始新的chunk
|
||||
chunks.append('\n'.join(current_chunk))
|
||||
current_chunk = [para]
|
||||
current_length = para_length
|
||||
|
||||
# 添加最后一个chunk
|
||||
if current_chunk:
|
||||
chunks.append('\n'.join(current_chunk))
|
||||
|
||||
return chunks
|
||||
|
||||
def get_embedding(self, text: str) -> list:
|
||||
"""获取文本的embedding向量"""
|
||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.post(url, json=payload, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"获取embedding失败: {response.text}")
|
||||
return None
|
||||
|
||||
return response.json()['data'][0]['embedding']
|
||||
|
||||
def process_files(self, knowledge_length:int=512):
|
||||
"""处理raw_info目录下的所有txt文件"""
|
||||
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
|
||||
|
||||
if not txt_files:
|
||||
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
|
||||
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
|
||||
return
|
||||
|
||||
total_stats = {
|
||||
"processed_files": 0,
|
||||
"total_chunks": 0,
|
||||
"failed_files": [],
|
||||
"skipped_files": []
|
||||
}
|
||||
|
||||
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
|
||||
|
||||
for filename in tqdm(txt_files, desc="处理文件进度"):
|
||||
file_path = os.path.join(self.raw_info_dir, filename)
|
||||
result = self.process_single_file(file_path, knowledge_length)
|
||||
self._update_stats(total_stats, result, filename)
|
||||
|
||||
self._display_processing_results(total_stats)
|
||||
|
||||
def process_single_file(self, file_path: str, knowledge_length: int = 512):
|
||||
"""处理单个文件"""
|
||||
result = {
|
||||
"status": "success",
|
||||
"chunks_processed": 0,
|
||||
"error": None
|
||||
}
|
||||
|
||||
try:
|
||||
current_hash = self.calculate_file_hash(file_path)
|
||||
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
|
||||
|
||||
if processed_record:
|
||||
if processed_record.get("hash") == current_hash:
|
||||
if knowledge_length in processed_record.get("split_by", []):
|
||||
result["status"] = "skipped"
|
||||
return result
|
||||
|
||||
content = self.read_file(file_path)
|
||||
chunks = self.split_content(content, knowledge_length)
|
||||
|
||||
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
|
||||
embedding = self.get_embedding(chunk)
|
||||
if embedding:
|
||||
knowledge = {
|
||||
"content": chunk,
|
||||
"embedding": embedding,
|
||||
"source_file": file_path,
|
||||
"split_length": knowledge_length,
|
||||
"created_at": datetime.now()
|
||||
}
|
||||
self.db.db.knowledges.insert_one(knowledge)
|
||||
result["chunks_processed"] += 1
|
||||
|
||||
split_by = processed_record.get("split_by", []) if processed_record else []
|
||||
if knowledge_length not in split_by:
|
||||
split_by.append(knowledge_length)
|
||||
|
||||
self.db.db.processed_files.update_one(
|
||||
{"file_path": file_path},
|
||||
{
|
||||
"$set": {
|
||||
"hash": current_hash,
|
||||
"last_processed": datetime.now(),
|
||||
"split_by": split_by
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result["status"] = "failed"
|
||||
result["error"] = str(e)
|
||||
|
||||
return result
|
||||
|
||||
def _update_stats(self, total_stats, result, filename):
|
||||
"""更新总体统计信息"""
|
||||
if result["status"] == "success":
|
||||
total_stats["processed_files"] += 1
|
||||
total_stats["total_chunks"] += result["chunks_processed"]
|
||||
elif result["status"] == "failed":
|
||||
total_stats["failed_files"].append((filename, result["error"]))
|
||||
elif result["status"] == "skipped":
|
||||
total_stats["skipped_files"].append(filename)
|
||||
|
||||
def _display_processing_results(self, stats):
|
||||
"""显示处理结果统计"""
|
||||
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
|
||||
|
||||
table = Table(show_header=True, header_style="bold magenta")
|
||||
table.add_column("统计项", style="dim")
|
||||
table.add_column("数值")
|
||||
|
||||
table.add_row("成功处理文件数", str(stats["processed_files"]))
|
||||
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
|
||||
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
|
||||
table.add_row("失败的文件数", str(len(stats["failed_files"])))
|
||||
|
||||
self.console.print(table)
|
||||
|
||||
if stats["failed_files"]:
|
||||
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
|
||||
for filename, error in stats["failed_files"]:
|
||||
self.console.print(f"[red]- {filename}: {error}[/red]")
|
||||
|
||||
if stats["skipped_files"]:
|
||||
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
|
||||
for filename in stats["skipped_files"]:
|
||||
self.console.print(f"[yellow]- {filename}[/yellow]")
|
||||
|
||||
def calculate_file_hash(self, file_path):
|
||||
"""计算文件的MD5哈希值"""
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
def search_similar_segments(self, query: str, limit: int = 5) -> list:
|
||||
"""搜索与查询文本相似的片段"""
|
||||
query_embedding = self.get_embedding(query)
|
||||
if not query_embedding:
|
||||
return []
|
||||
|
||||
# 使用余弦相似度计算
|
||||
pipeline = [
|
||||
{
|
||||
"$addFields": {
|
||||
"dotProduct": {
|
||||
"$reduce": {
|
||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||
"initialValue": 0,
|
||||
"in": {
|
||||
"$add": [
|
||||
"$$value",
|
||||
{"$multiply": [
|
||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||
{"$arrayElemAt": [query_embedding, "$$this"]}
|
||||
]}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude1": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": "$embedding",
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
},
|
||||
"magnitude2": {
|
||||
"$sqrt": {
|
||||
"$reduce": {
|
||||
"input": query_embedding,
|
||||
"initialValue": 0,
|
||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"similarity": {
|
||||
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
||||
}
|
||||
}
|
||||
},
|
||||
{"$sort": {"similarity": -1}},
|
||||
{"$limit": limit},
|
||||
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
||||
]
|
||||
|
||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||
return results
|
||||
|
||||
# 创建单例实例
|
||||
knowledge_library = KnowledgeLibrary()
|
||||
|
||||
if __name__ == "__main__":
|
||||
console = Console()
|
||||
console.print("[bold green]知识库处理工具[/bold green]")
|
||||
|
||||
while True:
|
||||
console.print("\n请选择要执行的操作:")
|
||||
console.print("[1] 麦麦开始学习")
|
||||
console.print("[2] 麦麦全部忘光光(仅知识)")
|
||||
console.print("[q] 退出程序")
|
||||
|
||||
choice = input("\n请输入选项: ").strip()
|
||||
|
||||
if choice.lower() == 'q':
|
||||
console.print("[yellow]程序退出[/yellow]")
|
||||
sys.exit(0)
|
||||
elif choice == '2':
|
||||
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
|
||||
if confirm == 'y':
|
||||
knowledge_library.db.db.knowledges.delete_many({})
|
||||
console.print("[green]已清空所有知识![/green]")
|
||||
continue
|
||||
elif choice == '1':
|
||||
if not os.path.exists(knowledge_library.raw_info_dir):
|
||||
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
|
||||
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
|
||||
|
||||
# 询问分割长度
|
||||
while True:
|
||||
try:
|
||||
length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip()
|
||||
if length_input.lower() == 'q':
|
||||
break
|
||||
if not length_input: # 如果直接回车,使用默认值
|
||||
knowledge_length = 512
|
||||
break
|
||||
knowledge_length = int(length_input)
|
||||
if knowledge_length <= 0:
|
||||
print("分割长度必须大于0,请重新输入")
|
||||
continue
|
||||
break
|
||||
except ValueError:
|
||||
print("请输入有效的数字")
|
||||
continue
|
||||
|
||||
if length_input.lower() == 'q':
|
||||
continue
|
||||
|
||||
# 测试知识库功能
|
||||
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
|
||||
knowledge_library.process_files(knowledge_length=knowledge_length)
|
||||
else:
|
||||
console.print("[red]无效的选项,请重新选择[/red]")
|
||||
continue
|
||||
45
麦麦开始学习.bat
Normal file
45
麦麦开始学习.bat
Normal file
@@ -0,0 +1,45 @@
|
||||
@echo off
|
||||
setlocal enabledelayedexpansion
|
||||
chcp 65001
|
||||
cd /d %~dp0
|
||||
|
||||
echo =====================================
|
||||
echo 选择Python环境:
|
||||
echo 1 - venv (推荐)
|
||||
echo 2 - conda
|
||||
echo =====================================
|
||||
choice /c 12 /n /m "输入数字(1或2): "
|
||||
|
||||
if errorlevel 2 (
|
||||
echo =====================================
|
||||
set "CONDA_ENV="
|
||||
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
|
||||
|
||||
:: 检查输入是否为空
|
||||
if "!CONDA_ENV!"=="" (
|
||||
echo 错误:环境名称不能为空
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
call conda activate !CONDA_ENV!
|
||||
if errorlevel 1 (
|
||||
echo 激活 conda 环境失败
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
echo Conda 环境 "!CONDA_ENV!" 激活成功
|
||||
python src/plugins/zhishi/knowledge_library.py
|
||||
) else (
|
||||
if exist "venv\Scripts\python.exe" (
|
||||
venv\Scripts\python src/plugins/zhishi/knowledge_library.py
|
||||
) else (
|
||||
echo =====================================
|
||||
echo 错误: venv环境不存在,请先创建虚拟环境
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
)
|
||||
endlocal
|
||||
pause
|
||||
Reference in New Issue
Block a user