新增了知识库一键启动漂亮脚本
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
data/
|
data/
|
||||||
|
data1/
|
||||||
mongodb/
|
mongodb/
|
||||||
NapCat.Framework.Windows.Once/
|
NapCat.Framework.Windows.Once/
|
||||||
log/
|
log/
|
||||||
|
|||||||
12
run.py
12
run.py
@@ -128,13 +128,17 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
os.system("cls")
|
os.system("cls")
|
||||||
if choice == "1":
|
if choice == "1":
|
||||||
install_napcat()
|
confirm = input("首次安装将下载并配置所需组件\n1.确认\n2.取消\n")
|
||||||
install_mongodb()
|
if confirm == "1":
|
||||||
|
install_napcat()
|
||||||
|
install_mongodb()
|
||||||
|
else:
|
||||||
|
print("已取消安装")
|
||||||
elif choice == "2":
|
elif choice == "2":
|
||||||
run_maimbot()
|
run_maimbot()
|
||||||
choice = input("是否启动推理可视化?(y/N)").upper()
|
choice = input("是否启动推理可视化?(未完善)(y/N)").upper()
|
||||||
if choice == "Y":
|
if choice == "Y":
|
||||||
run_cmd(r"python src\gui\reasoning_gui.py")
|
run_cmd(r"python src\gui\reasoning_gui.py")
|
||||||
choice = input("是否启动记忆可视化?(y/N)").upper()
|
choice = input("是否启动记忆可视化?(未完善)(y/N)").upper()
|
||||||
if choice == "Y":
|
if choice == "Y":
|
||||||
run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")
|
run_cmd(r"python src/plugins/memory_system/memory_manual_build.py")
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
_instance: Optional["Database"] = None
|
_instance: Optional["Database"] = None
|
||||||
|
|
||||||
@@ -50,25 +48,4 @@ class Database:
|
|||||||
def get_instance(cls) -> "Database":
|
def get_instance(cls) -> "Database":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
raise RuntimeError("Database not initialized")
|
raise RuntimeError("Database not initialized")
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
|
|
||||||
#测试用
|
|
||||||
|
|
||||||
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
|
||||||
# 先随机获取一条消息
|
|
||||||
random_message = list(self.db.messages.aggregate([
|
|
||||||
{"$match": {"group_id": group_id}},
|
|
||||||
{"$sample": {"size": 1}}
|
|
||||||
]))[0]
|
|
||||||
|
|
||||||
# 获取该消息之后的消息
|
|
||||||
subsequent_messages = list(self.db.messages.find({
|
|
||||||
"group_id": group_id,
|
|
||||||
"time": {"$gt": random_message["time"]}
|
|
||||||
}).sort("time", 1).limit(limit))
|
|
||||||
|
|
||||||
# 将随机消息和后续消息合并
|
|
||||||
messages = [random_message] + subsequent_messages
|
|
||||||
|
|
||||||
return messages
|
|
||||||
@@ -7,7 +7,6 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
|
|||||||
from ..memory_system.memory import hippocampus
|
from ..memory_system.memory import hippocampus
|
||||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
from ..moods.moods import MoodManager # 导入情绪管理器
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .cq_code import CQCode, cq_code_tool # 导入CQCode模块
|
|
||||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||||
from .llm_generator import ResponseGenerator
|
from .llm_generator import ResponseGenerator
|
||||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||||
@@ -218,7 +217,7 @@ class ChatBot:
|
|||||||
|
|
||||||
# message_set 可以直接加入 message_manager
|
# message_set 可以直接加入 message_manager
|
||||||
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
|
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
|
||||||
print(f"添加message_set到message_manager")
|
print("添加message_set到message_manager")
|
||||||
message_manager.add_message(message_set)
|
message_manager.add_message(message_set)
|
||||||
|
|
||||||
bot_response_time = thinking_time_point
|
bot_response_time = thinking_time_point
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from loguru import logger
|
|||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message import MessageRecv, MessageThinking, MessageSending,Message
|
from .message import MessageRecv, MessageThinking, Message
|
||||||
from .prompt_builder import prompt_builder
|
from .prompt_builder import prompt_builder
|
||||||
from .relationship_manager import relationship_manager
|
from .relationship_manager import relationship_manager
|
||||||
from .utils import process_llm_response
|
from .utils import process_llm_response
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ import html
|
|||||||
import re
|
import re
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, ForwardRef, List, Optional, Union
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
from .message_base import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
from .chat_stream import ChatStream, chat_manager
|
from .chat_stream import ChatStream
|
||||||
# 禁用SSL警告
|
# 禁用SSL警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from typing import List, Optional, Union, Any, Dict
|
from typing import List, Optional, Union, Dict
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Seg:
|
class Seg:
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, ForwardRef, List, Optional, Union
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from .cq_code import CQCode, cq_code_tool
|
from .cq_code import cq_code_tool
|
||||||
from .utils_cq import parse_cq_code
|
from .utils_cq import parse_cq_code
|
||||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
from .utils_user import get_groupname
|
||||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||||
# 禁用SSL警告
|
# 禁用SSL警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|||||||
@@ -5,12 +5,10 @@ from typing import Dict, List, Optional, Union
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot.adapters.onebot.v11 import Bot
|
from nonebot.adapters.onebot.v11 import Bot
|
||||||
|
|
||||||
from .cq_code import cq_code_tool
|
|
||||||
from .message_cq import MessageSendCQ
|
from .message_cq import MessageSendCQ
|
||||||
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
|
from .message import MessageSending, MessageThinking, MessageSet
|
||||||
from .storage import MessageStorage
|
from .storage import MessageStorage
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .chat_stream import chat_manager
|
|
||||||
|
|
||||||
|
|
||||||
class Message_Sender:
|
class Message_Sender:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from ..moods.moods import MoodManager
|
|||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
||||||
from .chat_stream import ChatStream, chat_manager
|
from .chat_stream import chat_manager
|
||||||
|
|
||||||
|
|
||||||
class PromptBuilder:
|
class PromptBuilder:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
from typing import Optional, Union
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .message_base import MessageBase
|
|
||||||
from .message import MessageSending, MessageRecv
|
from .message import MessageSending, MessageRecv
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from loguru import logger
|
|||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from ..utils.typo_generator import ChineseTypoGenerator
|
from ..utils.typo_generator import ChineseTypoGenerator
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
|
from .message import MessageRecv,Message
|
||||||
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
|
from .message_base import UserInfo
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from ..moods.moods import MoodManager
|
from ..moods.moods import MoodManager
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
import io
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import zlib
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Union
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
|
|||||||
@@ -1,13 +1,9 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .message_base import UserInfo, GroupInfo
|
from .chat_stream import ChatStream
|
||||||
from .chat_stream import chat_manager,ChatStream
|
|
||||||
|
|
||||||
|
|
||||||
class WillingManager:
|
class WillingManager:
|
||||||
|
|||||||
@@ -1,199 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
|
||||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
|
||||||
sys.path.append(root_path)
|
|
||||||
|
|
||||||
# 加载根目录下的env.edv文件
|
|
||||||
env_path = os.path.join(root_path, ".env.dev")
|
|
||||||
if not os.path.exists(env_path):
|
|
||||||
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
|
||||||
load_dotenv(env_path)
|
|
||||||
|
|
||||||
from src.common.database import Database
|
|
||||||
|
|
||||||
# 从环境变量获取配置
|
|
||||||
Database.initialize(
|
|
||||||
uri=os.getenv("MONGODB_URI"),
|
|
||||||
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
|
||||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
|
||||||
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
|
||||||
username=os.getenv("MONGODB_USERNAME"),
|
|
||||||
password=os.getenv("MONGODB_PASSWORD"),
|
|
||||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
|
||||||
)
|
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
|
||||||
def __init__(self):
|
|
||||||
self.db = Database.get_instance()
|
|
||||||
self.raw_info_dir = "data/raw_info"
|
|
||||||
self._ensure_dirs()
|
|
||||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
|
||||||
|
|
||||||
def _ensure_dirs(self):
|
|
||||||
"""确保必要的目录存在"""
|
|
||||||
os.makedirs(self.raw_info_dir, exist_ok=True)
|
|
||||||
|
|
||||||
def get_embedding(self, text: str) -> list:
|
|
||||||
"""获取文本的embedding向量"""
|
|
||||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
|
||||||
payload = {
|
|
||||||
"model": "BAAI/bge-m3",
|
|
||||||
"input": text,
|
|
||||||
"encoding_format": "float"
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, json=payload, headers=headers)
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"获取embedding失败: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response.json()['data'][0]['embedding']
|
|
||||||
|
|
||||||
def process_files(self):
|
|
||||||
"""处理raw_info目录下的所有txt文件"""
|
|
||||||
for filename in os.listdir(self.raw_info_dir):
|
|
||||||
if filename.endswith('.txt'):
|
|
||||||
file_path = os.path.join(self.raw_info_dir, filename)
|
|
||||||
self.process_single_file(file_path)
|
|
||||||
|
|
||||||
def process_single_file(self, file_path: str):
|
|
||||||
"""处理单个文件"""
|
|
||||||
try:
|
|
||||||
# 检查文件是否已处理
|
|
||||||
if self.db.db.processed_files.find_one({"file_path": file_path}):
|
|
||||||
print(f"文件已处理过,跳过: {file_path}")
|
|
||||||
return
|
|
||||||
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# 按1024字符分段
|
|
||||||
segments = [content[i:i+600] for i in range(0, len(content), 300)]
|
|
||||||
|
|
||||||
# 处理每个分段
|
|
||||||
for segment in segments:
|
|
||||||
if not segment.strip(): # 跳过空段
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取embedding
|
|
||||||
embedding = self.get_embedding(segment)
|
|
||||||
if not embedding:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 存储到数据库
|
|
||||||
doc = {
|
|
||||||
"content": segment,
|
|
||||||
"embedding": embedding,
|
|
||||||
"file_path": file_path,
|
|
||||||
"segment_length": len(segment)
|
|
||||||
}
|
|
||||||
|
|
||||||
# 使用文本内容的哈希值作为唯一标识
|
|
||||||
content_hash = hash(segment)
|
|
||||||
|
|
||||||
# 更新或插入文档
|
|
||||||
self.db.db.knowledges.update_one(
|
|
||||||
{"content_hash": content_hash},
|
|
||||||
{"$set": doc},
|
|
||||||
upsert=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记录文件已处理
|
|
||||||
self.db.db.processed_files.insert_one({
|
|
||||||
"file_path": file_path,
|
|
||||||
"processed_time": time.time()
|
|
||||||
})
|
|
||||||
|
|
||||||
print(f"成功处理文件: {file_path}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"处理文件 {file_path} 时出错: {str(e)}")
|
|
||||||
|
|
||||||
def search_similar_segments(self, query: str, limit: int = 5) -> list:
|
|
||||||
"""搜索与查询文本相似的片段"""
|
|
||||||
query_embedding = self.get_embedding(query)
|
|
||||||
if not query_embedding:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 使用余弦相似度计算
|
|
||||||
pipeline = [
|
|
||||||
{
|
|
||||||
"$addFields": {
|
|
||||||
"dotProduct": {
|
|
||||||
"$reduce": {
|
|
||||||
"input": {"$range": [0, {"$size": "$embedding"}]},
|
|
||||||
"initialValue": 0,
|
|
||||||
"in": {
|
|
||||||
"$add": [
|
|
||||||
"$$value",
|
|
||||||
{"$multiply": [
|
|
||||||
{"$arrayElemAt": ["$embedding", "$$this"]},
|
|
||||||
{"$arrayElemAt": [query_embedding, "$$this"]}
|
|
||||||
]}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"magnitude1": {
|
|
||||||
"$sqrt": {
|
|
||||||
"$reduce": {
|
|
||||||
"input": "$embedding",
|
|
||||||
"initialValue": 0,
|
|
||||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"magnitude2": {
|
|
||||||
"$sqrt": {
|
|
||||||
"$reduce": {
|
|
||||||
"input": query_embedding,
|
|
||||||
"initialValue": 0,
|
|
||||||
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"$addFields": {
|
|
||||||
"similarity": {
|
|
||||||
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{"$sort": {"similarity": -1}},
|
|
||||||
{"$limit": limit},
|
|
||||||
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
|
||||||
]
|
|
||||||
|
|
||||||
results = list(self.db.db.knowledges.aggregate(pipeline))
|
|
||||||
return results
|
|
||||||
|
|
||||||
# 创建单例实例
|
|
||||||
knowledge_library = KnowledgeLibrary()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试知识库功能
|
|
||||||
print("开始处理知识库文件...")
|
|
||||||
knowledge_library.process_files()
|
|
||||||
|
|
||||||
# 测试搜索功能
|
|
||||||
test_query = "麦麦评价一下僕と花"
|
|
||||||
print(f"\n搜索与'{test_query}'相似的内容:")
|
|
||||||
results = knowledge_library.search_similar_segments(test_query)
|
|
||||||
for result in results:
|
|
||||||
print(f"相似度: {result['similarity']:.4f}")
|
|
||||||
print(f"内容: {result['content'][:100]}...")
|
|
||||||
print("-" * 50)
|
|
||||||
@@ -10,7 +10,6 @@ from pathlib import Path
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import pymongo
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
import jieba
|
import jieba
|
||||||
|
|||||||
383
src/plugins/zhishi/knowledge_library.py
Normal file
383
src/plugins/zhishi/knowledge_library.py
Normal file
@@ -0,0 +1,383 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import hashlib
|
||||||
|
from datetime import datetime
|
||||||
|
from tqdm import tqdm
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
sys.path.append(root_path)
|
||||||
|
|
||||||
|
# 现在可以导入src模块
|
||||||
|
from src.common.database import Database
|
||||||
|
|
||||||
|
# 加载根目录下的env.edv文件
|
||||||
|
env_path = os.path.join(root_path, ".env.prod")
|
||||||
|
if not os.path.exists(env_path):
|
||||||
|
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||||
|
load_dotenv(env_path)
|
||||||
|
|
||||||
|
class KnowledgeLibrary:
|
||||||
|
def __init__(self):
|
||||||
|
# 初始化数据库连接
|
||||||
|
if Database._instance is None:
|
||||||
|
Database.initialize(
|
||||||
|
uri=os.getenv("MONGODB_URI"),
|
||||||
|
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
|
||||||
|
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||||
|
db_name=os.getenv("DATABASE_NAME", "MegBot"),
|
||||||
|
username=os.getenv("MONGODB_USERNAME"),
|
||||||
|
password=os.getenv("MONGODB_PASSWORD"),
|
||||||
|
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
|
||||||
|
)
|
||||||
|
self.db = Database.get_instance()
|
||||||
|
self.raw_info_dir = "data/raw_info"
|
||||||
|
self._ensure_dirs()
|
||||||
|
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||||
|
self.console = Console()
|
||||||
|
|
||||||
|
def _ensure_dirs(self):
|
||||||
|
"""确保必要的目录存在"""
|
||||||
|
os.makedirs(self.raw_info_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def read_file(self, file_path: str) -> str:
|
||||||
|
"""读取文件内容"""
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
def split_content(self, content: str, max_length: int = 512) -> list:
|
||||||
|
"""将内容分割成适当大小的块,保持段落完整性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 要分割的文本内容
|
||||||
|
max_length: 每个块的最大长度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 分割后的文本块列表
|
||||||
|
"""
|
||||||
|
# 首先按段落分割
|
||||||
|
paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]
|
||||||
|
chunks = []
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
for para in paragraphs:
|
||||||
|
para_length = len(para)
|
||||||
|
|
||||||
|
# 如果单个段落就超过最大长度
|
||||||
|
if para_length > max_length:
|
||||||
|
# 如果当前chunk不为空,先保存
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append('\n'.join(current_chunk))
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
# 将长段落按句子分割
|
||||||
|
sentences = [s.strip() for s in para.replace('。', '。\n').replace('!', '!\n').replace('?', '?\n').split('\n') if s.strip()]
|
||||||
|
temp_chunk = []
|
||||||
|
temp_length = 0
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
sentence_length = len(sentence)
|
||||||
|
if sentence_length > max_length:
|
||||||
|
# 如果单个句子超长,强制按长度分割
|
||||||
|
if temp_chunk:
|
||||||
|
chunks.append('\n'.join(temp_chunk))
|
||||||
|
temp_chunk = []
|
||||||
|
temp_length = 0
|
||||||
|
for i in range(0, len(sentence), max_length):
|
||||||
|
chunks.append(sentence[i:i + max_length])
|
||||||
|
elif temp_length + sentence_length + 1 <= max_length:
|
||||||
|
temp_chunk.append(sentence)
|
||||||
|
temp_length += sentence_length + 1
|
||||||
|
else:
|
||||||
|
chunks.append('\n'.join(temp_chunk))
|
||||||
|
temp_chunk = [sentence]
|
||||||
|
temp_length = sentence_length
|
||||||
|
|
||||||
|
if temp_chunk:
|
||||||
|
chunks.append('\n'.join(temp_chunk))
|
||||||
|
|
||||||
|
# 如果当前段落加上现有chunk不超过最大长度
|
||||||
|
elif current_length + para_length + 1 <= max_length:
|
||||||
|
current_chunk.append(para)
|
||||||
|
current_length += para_length + 1
|
||||||
|
else:
|
||||||
|
# 保存当前chunk并开始新的chunk
|
||||||
|
chunks.append('\n'.join(current_chunk))
|
||||||
|
current_chunk = [para]
|
||||||
|
current_length = para_length
|
||||||
|
|
||||||
|
# 添加最后一个chunk
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append('\n'.join(current_chunk))
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def get_embedding(self, text: str) -> list:
|
||||||
|
"""获取文本的embedding向量"""
|
||||||
|
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||||
|
payload = {
|
||||||
|
"model": "BAAI/bge-m3",
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, json=payload, headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"获取embedding失败: {response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return response.json()['data'][0]['embedding']
|
||||||
|
|
||||||
|
def process_files(self, knowledge_length:int=512):
|
||||||
|
"""处理raw_info目录下的所有txt文件"""
|
||||||
|
txt_files = [f for f in os.listdir(self.raw_info_dir) if f.endswith('.txt')]
|
||||||
|
|
||||||
|
if not txt_files:
|
||||||
|
self.console.print("[red]警告:在 {} 目录下没有找到任何txt文件[/red]".format(self.raw_info_dir))
|
||||||
|
self.console.print("[yellow]请将需要处理的文本文件放入该目录后再运行程序[/yellow]")
|
||||||
|
return
|
||||||
|
|
||||||
|
total_stats = {
|
||||||
|
"processed_files": 0,
|
||||||
|
"total_chunks": 0,
|
||||||
|
"failed_files": [],
|
||||||
|
"skipped_files": []
|
||||||
|
}
|
||||||
|
|
||||||
|
self.console.print(f"\n[bold blue]开始处理知识库文件 - 共{len(txt_files)}个文件[/bold blue]")
|
||||||
|
|
||||||
|
for filename in tqdm(txt_files, desc="处理文件进度"):
|
||||||
|
file_path = os.path.join(self.raw_info_dir, filename)
|
||||||
|
result = self.process_single_file(file_path, knowledge_length)
|
||||||
|
self._update_stats(total_stats, result, filename)
|
||||||
|
|
||||||
|
self._display_processing_results(total_stats)
|
||||||
|
|
||||||
|
def process_single_file(self, file_path: str, knowledge_length: int = 512):
|
||||||
|
"""处理单个文件"""
|
||||||
|
result = {
|
||||||
|
"status": "success",
|
||||||
|
"chunks_processed": 0,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_hash = self.calculate_file_hash(file_path)
|
||||||
|
processed_record = self.db.db.processed_files.find_one({"file_path": file_path})
|
||||||
|
|
||||||
|
if processed_record:
|
||||||
|
if processed_record.get("hash") == current_hash:
|
||||||
|
if knowledge_length in processed_record.get("split_by", []):
|
||||||
|
result["status"] = "skipped"
|
||||||
|
return result
|
||||||
|
|
||||||
|
content = self.read_file(file_path)
|
||||||
|
chunks = self.split_content(content, knowledge_length)
|
||||||
|
|
||||||
|
for chunk in tqdm(chunks, desc=f"处理 {os.path.basename(file_path)} 的文本块", leave=False):
|
||||||
|
embedding = self.get_embedding(chunk)
|
||||||
|
if embedding:
|
||||||
|
knowledge = {
|
||||||
|
"content": chunk,
|
||||||
|
"embedding": embedding,
|
||||||
|
"source_file": file_path,
|
||||||
|
"split_length": knowledge_length,
|
||||||
|
"created_at": datetime.now()
|
||||||
|
}
|
||||||
|
self.db.db.knowledges.insert_one(knowledge)
|
||||||
|
result["chunks_processed"] += 1
|
||||||
|
|
||||||
|
split_by = processed_record.get("split_by", []) if processed_record else []
|
||||||
|
if knowledge_length not in split_by:
|
||||||
|
split_by.append(knowledge_length)
|
||||||
|
|
||||||
|
self.db.db.processed_files.update_one(
|
||||||
|
{"file_path": file_path},
|
||||||
|
{
|
||||||
|
"$set": {
|
||||||
|
"hash": current_hash,
|
||||||
|
"last_processed": datetime.now(),
|
||||||
|
"split_by": split_by
|
||||||
|
}
|
||||||
|
},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result["status"] = "failed"
|
||||||
|
result["error"] = str(e)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _update_stats(self, total_stats, result, filename):
|
||||||
|
"""更新总体统计信息"""
|
||||||
|
if result["status"] == "success":
|
||||||
|
total_stats["processed_files"] += 1
|
||||||
|
total_stats["total_chunks"] += result["chunks_processed"]
|
||||||
|
elif result["status"] == "failed":
|
||||||
|
total_stats["failed_files"].append((filename, result["error"]))
|
||||||
|
elif result["status"] == "skipped":
|
||||||
|
total_stats["skipped_files"].append(filename)
|
||||||
|
|
||||||
|
def _display_processing_results(self, stats):
|
||||||
|
"""显示处理结果统计"""
|
||||||
|
self.console.print("\n[bold green]处理完成!统计信息如下:[/bold green]")
|
||||||
|
|
||||||
|
table = Table(show_header=True, header_style="bold magenta")
|
||||||
|
table.add_column("统计项", style="dim")
|
||||||
|
table.add_column("数值")
|
||||||
|
|
||||||
|
table.add_row("成功处理文件数", str(stats["processed_files"]))
|
||||||
|
table.add_row("处理的知识块总数", str(stats["total_chunks"]))
|
||||||
|
table.add_row("跳过的文件数", str(len(stats["skipped_files"])))
|
||||||
|
table.add_row("失败的文件数", str(len(stats["failed_files"])))
|
||||||
|
|
||||||
|
self.console.print(table)
|
||||||
|
|
||||||
|
if stats["failed_files"]:
|
||||||
|
self.console.print("\n[bold red]处理失败的文件:[/bold red]")
|
||||||
|
for filename, error in stats["failed_files"]:
|
||||||
|
self.console.print(f"[red]- {filename}: {error}[/red]")
|
||||||
|
|
||||||
|
if stats["skipped_files"]:
|
||||||
|
self.console.print("\n[bold yellow]跳过的文件(已处理):[/bold yellow]")
|
||||||
|
for filename in stats["skipped_files"]:
|
||||||
|
self.console.print(f"[yellow]- {filename}[/yellow]")
|
||||||
|
|
||||||
|
def calculate_file_hash(self, file_path):
|
||||||
|
"""计算文件的MD5哈希值"""
|
||||||
|
hash_md5 = hashlib.md5()
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
for chunk in iter(lambda: f.read(4096), b""):
|
||||||
|
hash_md5.update(chunk)
|
||||||
|
return hash_md5.hexdigest()
|
||||||
|
|
||||||
|
def search_similar_segments(self, query: str, limit: int = 5) -> list:
|
||||||
|
"""搜索与查询文本相似的片段"""
|
||||||
|
query_embedding = self.get_embedding(query)
|
||||||
|
if not query_embedding:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用余弦相似度计算
|
||||||
|
pipeline = [
|
||||||
|
{
|
||||||
|
"$addFields": {
|
||||||
|
"dotProduct": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": {"$range": [0, {"$size": "$embedding"}]},
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {
|
||||||
|
"$add": [
|
||||||
|
"$$value",
|
||||||
|
{"$multiply": [
|
||||||
|
{"$arrayElemAt": ["$embedding", "$$this"]},
|
||||||
|
{"$arrayElemAt": [query_embedding, "$$this"]}
|
||||||
|
]}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"magnitude1": {
|
||||||
|
"$sqrt": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": "$embedding",
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"magnitude2": {
|
||||||
|
"$sqrt": {
|
||||||
|
"$reduce": {
|
||||||
|
"input": query_embedding,
|
||||||
|
"initialValue": 0,
|
||||||
|
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$addFields": {
|
||||||
|
"similarity": {
|
||||||
|
"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{"$sort": {"similarity": -1}},
|
||||||
|
{"$limit": limit},
|
||||||
|
{"$project": {"content": 1, "similarity": 1, "file_path": 1}}
|
||||||
|
]
|
||||||
|
|
||||||
|
results = list(self.db.db.knowledges.aggregate(pipeline))
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 创建单例实例
|
||||||
|
knowledge_library = KnowledgeLibrary()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
console = Console()
|
||||||
|
console.print("[bold green]知识库处理工具[/bold green]")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
console.print("\n请选择要执行的操作:")
|
||||||
|
console.print("[1] 麦麦开始学习")
|
||||||
|
console.print("[2] 麦麦全部忘光光(仅知识)")
|
||||||
|
console.print("[q] 退出程序")
|
||||||
|
|
||||||
|
choice = input("\n请输入选项: ").strip()
|
||||||
|
|
||||||
|
if choice.lower() == 'q':
|
||||||
|
console.print("[yellow]程序退出[/yellow]")
|
||||||
|
sys.exit(0)
|
||||||
|
elif choice == '2':
|
||||||
|
confirm = input("确定要删除所有知识吗?这个操作不可撤销!(y/n): ").strip().lower()
|
||||||
|
if confirm == 'y':
|
||||||
|
knowledge_library.db.db.knowledges.delete_many({})
|
||||||
|
console.print("[green]已清空所有知识![/green]")
|
||||||
|
continue
|
||||||
|
elif choice == '1':
|
||||||
|
if not os.path.exists(knowledge_library.raw_info_dir):
|
||||||
|
console.print(f"[yellow]创建目录:{knowledge_library.raw_info_dir}[/yellow]")
|
||||||
|
os.makedirs(knowledge_library.raw_info_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 询问分割长度
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
length_input = input("请输入知识分割长度(默认512,输入q退出,回车使用默认值): ").strip()
|
||||||
|
if length_input.lower() == 'q':
|
||||||
|
break
|
||||||
|
if not length_input: # 如果直接回车,使用默认值
|
||||||
|
knowledge_length = 512
|
||||||
|
break
|
||||||
|
knowledge_length = int(length_input)
|
||||||
|
if knowledge_length <= 0:
|
||||||
|
print("分割长度必须大于0,请重新输入")
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
print("请输入有效的数字")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if length_input.lower() == 'q':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 测试知识库功能
|
||||||
|
print(f"开始处理知识库文件,使用分割长度: {knowledge_length}...")
|
||||||
|
knowledge_library.process_files(knowledge_length=knowledge_length)
|
||||||
|
else:
|
||||||
|
console.print("[red]无效的选项,请重新选择[/red]")
|
||||||
|
continue
|
||||||
45
麦麦开始学习.bat
Normal file
45
麦麦开始学习.bat
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
@echo off
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
chcp 65001
|
||||||
|
cd /d %~dp0
|
||||||
|
|
||||||
|
echo =====================================
|
||||||
|
echo 选择Python环境:
|
||||||
|
echo 1 - venv (推荐)
|
||||||
|
echo 2 - conda
|
||||||
|
echo =====================================
|
||||||
|
choice /c 12 /n /m "输入数字(1或2): "
|
||||||
|
|
||||||
|
if errorlevel 2 (
|
||||||
|
echo =====================================
|
||||||
|
set "CONDA_ENV="
|
||||||
|
set /p CONDA_ENV="请输入要激活的 conda 环境名称: "
|
||||||
|
|
||||||
|
:: 检查输入是否为空
|
||||||
|
if "!CONDA_ENV!"=="" (
|
||||||
|
echo 错误:环境名称不能为空
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
call conda activate !CONDA_ENV!
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo 激活 conda 环境失败
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo Conda 环境 "!CONDA_ENV!" 激活成功
|
||||||
|
python src/plugins/zhishi/knowledge_library.py
|
||||||
|
) else (
|
||||||
|
if exist "venv\Scripts\python.exe" (
|
||||||
|
venv\Scripts\python src/plugins/zhishi/knowledge_library.py
|
||||||
|
) else (
|
||||||
|
echo =====================================
|
||||||
|
echo 错误: venv环境不存在,请先创建虚拟环境
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
)
|
||||||
|
endlocal
|
||||||
|
pause
|
||||||
Reference in New Issue
Block a user