Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
@@ -10,13 +10,13 @@ from time import sleep
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

-from src.chat.knowledge.src.lpmmconfig import PG_NAMESPACE, global_config
-from src.chat.knowledge.src.embedding_store import EmbeddingManager
-from src.chat.knowledge.src.llm_client import LLMClient
-from src.chat.knowledge.src.open_ie import OpenIE
-from src.chat.knowledge.src.kg_manager import KGManager
+from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
+from src.chat.knowledge.embedding_store import EmbeddingManager
+from src.chat.knowledge.llm_client import LLMClient
+from src.chat.knowledge.open_ie import OpenIE
+from src.chat.knowledge.kg_manager import KGManager
 from src.common.logger import get_module_logger
-from src.chat.knowledge.src.utils.hash import get_sha256
+from src.chat.knowledge.utils.hash import get_sha256


 # Add the project root directory to sys.path

@@ -13,11 +13,11 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from rich.progress import Progress  # switched to the rich progress bar

 from src.common.logger import get_module_logger
-from src.chat.knowledge.src.lpmmconfig import global_config
-from src.chat.knowledge.src.ie_process import info_extract_from_str
-from src.chat.knowledge.src.llm_client import LLMClient
-from src.chat.knowledge.src.open_ie import OpenIE
-from src.chat.knowledge.src.raw_processing import load_raw_data
+from src.chat.knowledge.lpmmconfig import global_config
+from src.chat.knowledge.ie_process import info_extract_from_str
+from src.chat.knowledge.llm_client import LLMClient
+from src.chat.knowledge.open_ie import OpenIE
+from src.chat.knowledge.raw_processing import load_raw_data
 from rich.progress import (
     BarColumn,
     TimeElapsedColumn,
@@ -6,7 +6,7 @@ import datetime  # newly added import

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from src.common.logger_manager import get_logger
-from src.chat.knowledge.src.lpmmconfig import global_config
+from src.chat.knowledge.lpmmconfig import global_config

 logger = get_logger("lpmm")
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -27,7 +27,7 @@ from rich.progress import (
 )

 install(extra_lines=3)
-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 EMBEDDING_DATA_DIR = (
     os.path.join(ROOT_PATH, "data", "embedding")
     if global_config["persistence"]["embedding_data_dir"] is None
@@ -6,7 +6,7 @@ from .global_logger import logger
 from . import prompt_template
 from .lpmmconfig import global_config, INVALID_ENTITY
 from .llm_client import LLMClient
-from .utils.json_fix import new_fix_broken_generated_json
+from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json


 def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
@@ -31,7 +31,7 @@ from .lpmmconfig import (

 from .global_logger import logger

-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 KG_DIR = (
     os.path.join(ROOT_PATH, "data/rag")
     if global_config["persistence"]["rag_data_dir"] is None
@@ -1,10 +1,10 @@
-from .src.lpmmconfig import PG_NAMESPACE, global_config
-from .src.embedding_store import EmbeddingManager
-from .src.llm_client import LLMClient
-from .src.mem_active_manager import MemoryActiveManager
-from .src.qa_manager import QAManager
-from .src.kg_manager import KGManager
-from .src.global_logger import logger
+from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
+from src.chat.knowledge.embedding_store import EmbeddingManager
+from src.chat.knowledge.llm_client import LLMClient
+from src.chat.knowledge.mem_active_manager import MemoryActiveManager
+from src.chat.knowledge.qa_manager import QAManager
+from src.chat.knowledge.kg_manager import KGManager
+from src.chat.knowledge.global_logger import logger
 # try:
 #     import quick_algo
 # except ImportError:
@@ -45,7 +45,7 @@ def _load_config(config, config_file_path):
     if "llm_providers" in file_config:
         for provider in file_config["llm_providers"]:
             if provider["name"] not in config["llm_providers"]:
-                config["llm_providers"][provider["name"]] = dict()
+                config["llm_providers"][provider["name"]] = {}
             config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
             config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]

@@ -135,6 +135,6 @@ global_config = dict(
 # _load_config(global_config, parser.parse_args().config_path)
 # file_path = os.path.abspath(__file__)
 # dir_path = os.path.dirname(file_path)
-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
 _load_config(global_config, config_path)
@@ -3,7 +3,7 @@ import os

 from .global_logger import logger
 from .lpmmconfig import global_config
-from .utils.hash import get_sha256
+from src.chat.knowledge.utils import get_sha256


 def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
@@ -32,6 +32,7 @@ from src.config.official_configs import (
     FocusChatProcessorConfig,
     MessageReceiveConfig,
     MaimMessageConfig,
+    LPMMKnowledgeConfig,
     RelationshipConfig,
 )

@@ -161,6 +162,7 @@ class Config(ConfigBase):
     experimental: ExperimentalConfig
     model: ModelConfig
     maim_message: MaimMessageConfig
+    lpmm_knowledge: LPMMKnowledgeConfig


 def load_config(config_path: str) -> Config:
@@ -414,6 +414,44 @@ class MaimMessageConfig(ConfigBase):
     """Authentication token used for API verification; leave empty to disable verification"""


+@dataclass
+class LPMMKnowledgeConfig(ConfigBase):
+    """LPMM knowledge base configuration class"""
+
+    enable: bool = True
+    """Whether to enable the LPMM knowledge base"""
+
+    rag_synonym_search_top_k: int = 10
+    """Top-K for RAG synonym search"""
+
+    rag_synonym_threshold: float = 0.8
+    """Similarity threshold for RAG synonym search"""
+
+    info_extraction_workers: int = 3
+    """Number of information-extraction worker threads"""
+
+    qa_relation_search_top_k: int = 10
+    """Top-K for QA relation search"""
+
+    qa_relation_threshold: float = 0.75
+    """Similarity threshold for QA relation search"""
+
+    qa_paragraph_search_top_k: int = 1000
+    """Top-K for QA paragraph search"""
+
+    qa_paragraph_node_weight: float = 0.05
+    """QA paragraph node weight"""
+
+    qa_ent_filter_top_k: int = 10
+    """Top-K for QA entity filtering"""
+
+    qa_ppr_damping: float = 0.8
+    """QA PageRank damping factor"""
+
+    qa_res_top_k: int = 10
+    """Top-K for the final QA results"""
+
+
 @dataclass
 class ModelConfig(ConfigBase):
     """Model configuration class"""
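Note: the dataclass above only declares defaults; at runtime the values are presumably populated from the `[lpmm_knowledge]` TOML table added further down in this diff. As a rough illustration of that mapping, here is a minimal, self-contained sketch; the `load_lpmm_knowledge` helper, the plain `@dataclass` (instead of MaiBot's `ConfigBase`), and the use of the standard-library `tomllib` are assumptions for the example, not code from this commit.

# Illustrative sketch only (Python 3.11+ for tomllib); MaiBot's real path is load_config()/ConfigBase.
import tomllib
from dataclasses import dataclass, fields

@dataclass
class LPMMKnowledgeConfig:
    # Subset of the fields shown in the diff above, with the same defaults.
    enable: bool = True
    rag_synonym_search_top_k: int = 10
    rag_synonym_threshold: float = 0.8
    qa_ppr_damping: float = 0.8
    qa_res_top_k: int = 10

def load_lpmm_knowledge(path: str) -> LPMMKnowledgeConfig:
    # Hypothetical helper: read the TOML file and keep only keys the dataclass defines,
    # so unknown keys in the config file are ignored instead of raising TypeError.
    with open(path, "rb") as f:
        table = tomllib.load(f).get("lpmm_knowledge", {})
    known = {f.name for f in fields(LPMMKnowledgeConfig)}
    return LPMMKnowledgeConfig(**{k: v for k, v in table.items() if k in known})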
@@ -1,5 +1,5 @@
 [inner]
-version = "2.14.0"
+version = "2.15.0"

 #---- The notes below are for developers; if you have only deployed 麦麦, you do not need to read them ----
 # If you modify this config file, please update the value of version after making changes
@@ -137,6 +137,18 @@ mood_update_interval = 1.0 # mood update interval, in seconds
 mood_decay_rate = 0.95 # mood decay rate
 mood_intensity_factor = 1.0 # mood intensity factor

+[lpmm_knowledge] # LPMM knowledge base settings
+enable = true # whether to enable the LPMM knowledge base
+rag_synonym_search_top_k = 10 # top-K for synonym search
+rag_synonym_threshold = 0.8 # synonym threshold (words with similarity above it are treated as synonyms)
+info_extraction_workers = 3 # number of concurrent entity-extraction threads; do not set above 5 for non-Pro models
+qa_relation_search_top_k = 10 # top-K for relation search
+qa_relation_threshold = 0.5 # relation threshold (relations with similarity above it are treated as relevant)
+qa_paragraph_search_top_k = 1000 # top-K for paragraph search (do not set this too small, it may hurt search results)
+qa_paragraph_node_weight = 0.05 # paragraph node weight (weight in graph search & PPR; has no effect when search uses DPR only)
+qa_ent_filter_top_k = 10 # top-K for entity filtering
+qa_ppr_damping = 0.8 # PPR damping factor
+qa_res_top_k = 3 # top-K of passages ultimately returned

 # keyword_rules is used to configure extra reply knowledge triggered by keywords
 # To add a new rule, append an entry to the keyword_rules array in the following format:
@@ -273,7 +285,30 @@ temp = 0.7
 enable_thinking = false # whether to enable thinking (qwen3 only)


+#------------ LPMM knowledge base models ------------
+
+[model.lpmm_entity_extract] # entity extraction model
+name = "Pro/deepseek-ai/DeepSeek-V3"
+provider = "SILICONFLOW"
+pri_in = 2
+pri_out = 8
+temp = 0.2
+
+[model.lpmm_rdf_build] # RDF construction model
+name = "Pro/deepseek-ai/DeepSeek-V3"
+provider = "SILICONFLOW"
+pri_in = 2
+pri_out = 8
+temp = 0.2
+
+[model.lpmm_qa] # question answering model
+name = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+provider = "SILICONFLOW"
+pri_in = 4.0
+pri_out = 16.0
+temp = 0.7
+
 [maim_message]
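For a quick sanity check of the three new model tables, a minimal sketch follows; the config path and the use of `tomllib` are assumptions for illustration, not how MaiBot itself reads its configuration.

# Illustrative sketch only (Python 3.11+): print the LPMM model entries added above.
import tomllib

CONFIG_PATH = "config/bot_config.toml"  # hypothetical path to the deployed config file

with open(CONFIG_PATH, "rb") as f:
    cfg = tomllib.load(f)

# [model.lpmm_entity_extract] etc. become nested tables under cfg["model"].
for section in ("lpmm_entity_extract", "lpmm_rdf_build", "lpmm_qa"):
    entry = cfg["model"][section]
    print(f'{section}: {entry["name"]} via {entry["provider"]} (temp={entry["temp"]})')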
@@ -296,3 +331,4 @@ enable_friend_chat = false # whether to enable friend chat



+