Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -20,3 +20,4 @@
|
|||||||
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
- `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
|
||||||
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
- `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
|
||||||
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
- `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
|
||||||
|
4. 现在增加了参数类型检查,完善了对应注释
|
||||||
@@ -9,19 +9,17 @@ import os
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
|
||||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||||
from src.chat.knowledge.llm_client import LLMClient
|
|
||||||
from src.chat.knowledge.open_ie import OpenIE
|
from src.chat.knowledge.open_ie import OpenIE
|
||||||
from src.chat.knowledge.kg_manager import KGManager
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.utils.hash import get_sha256
|
from src.chat.knowledge.utils.hash import get_sha256
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
|
|
||||||
logger = get_logger("OpenIE导入")
|
logger = get_logger("OpenIE导入")
|
||||||
|
|
||||||
@@ -63,7 +61,7 @@ def hash_deduplicate(
|
|||||||
):
|
):
|
||||||
# 段落hash
|
# 段落hash
|
||||||
paragraph_hash = get_sha256(raw_paragraph)
|
paragraph_hash = get_sha256(raw_paragraph)
|
||||||
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
|
||||||
continue
|
continue
|
||||||
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
new_raw_paragraphs[paragraph_hash] = raw_paragraph
|
||||||
new_triple_list_data[paragraph_hash] = triple_list
|
new_triple_list_data[paragraph_hash] = triple_list
|
||||||
@@ -193,15 +191,9 @@ def main(): # sourcery skip: dict-comprehension
|
|||||||
logger.info("----开始导入openie数据----\n")
|
logger.info("----开始导入openie数据----\n")
|
||||||
|
|
||||||
logger.info("创建LLM客户端")
|
logger.info("创建LLM客户端")
|
||||||
llm_client_list = {}
|
|
||||||
for key in global_config["llm_providers"]:
|
|
||||||
llm_client_list[key] = LLMClient(
|
|
||||||
global_config["llm_providers"][key]["base_url"],
|
|
||||||
global_config["llm_providers"][key]["api_key"],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
embed_manager = EmbeddingManager()
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
@@ -230,7 +222,7 @@ def main(): # sourcery skip: dict-comprehension
|
|||||||
|
|
||||||
# 数据比对:Embedding库与KG的段落hash集合
|
# 数据比对:Embedding库与KG的段落hash集合
|
||||||
for pg_hash in kg_manager.stored_paragraph_hashes:
|
for pg_hash in kg_manager.stored_paragraph_hashes:
|
||||||
key = f"{PG_NAMESPACE}-{pg_hash}"
|
key = f"{local_storage['pg_namespace']}-{pg_hash}"
|
||||||
if key not in embed_manager.stored_pg_hashes:
|
if key not in embed_manager.stored_pg_hashes:
|
||||||
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import signal
|
|||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from threading import Lock, Event
|
from threading import Lock, Event
|
||||||
import sys
|
import sys
|
||||||
import glob
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
@@ -13,11 +12,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|||||||
from rich.progress import Progress # 替换为 rich 进度条
|
from rich.progress import Progress # 替换为 rich 进度条
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.lpmmconfig import global_config
|
# from src.chat.knowledge.lpmmconfig import global_config
|
||||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
from src.chat.knowledge.ie_process import info_extract_from_str
|
||||||
from src.chat.knowledge.llm_client import LLMClient
|
|
||||||
from src.chat.knowledge.open_ie import OpenIE
|
from src.chat.knowledge.open_ie import OpenIE
|
||||||
from src.chat.knowledge.raw_processing import load_raw_data
|
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
BarColumn,
|
BarColumn,
|
||||||
TimeElapsedColumn,
|
TimeElapsedColumn,
|
||||||
@@ -27,24 +24,17 @@ from rich.progress import (
|
|||||||
SpinnerColumn,
|
SpinnerColumn,
|
||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
|
from raw_data_preprocessor import RAW_DATA_PATH, process_multi_files, load_raw_data
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
logger = get_logger("LPMM知识库-信息提取")
|
logger = get_logger("LPMM知识库-信息提取")
|
||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||||
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||||
ROOT_PATH, "data", "imported_lpmm_data"
|
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
)
|
|
||||||
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
|
||||||
|
|
||||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
|
||||||
file_lock = Lock()
|
|
||||||
open_ie_doc_lock = Lock()
|
|
||||||
|
|
||||||
# 创建一个事件标志,用于控制程序终止
|
|
||||||
shutdown_event = Event()
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_dirs():
|
def ensure_dirs():
|
||||||
"""确保临时目录和输出目录存在"""
|
"""确保临时目录和输出目录存在"""
|
||||||
@@ -54,12 +44,26 @@ def ensure_dirs():
|
|||||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||||
if not os.path.exists(IMPORTED_DATA_PATH):
|
if not os.path.exists(RAW_DATA_PATH):
|
||||||
os.makedirs(IMPORTED_DATA_PATH)
|
os.makedirs(RAW_DATA_PATH)
|
||||||
logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
|
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
|
||||||
|
|
||||||
|
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||||
|
file_lock = Lock()
|
||||||
|
open_ie_doc_lock = Lock()
|
||||||
|
|
||||||
def process_single_text(pg_hash, raw_data, llm_client_list):
|
# 创建一个事件标志,用于控制程序终止
|
||||||
|
shutdown_event = Event()
|
||||||
|
|
||||||
|
lpmm_entity_extract_llm = LLMRequest(
|
||||||
|
model=global_config.model.lpmm_entity_extract,
|
||||||
|
request_type="lpmm.entity_extract"
|
||||||
|
)
|
||||||
|
lpmm_rdf_build_llm = LLMRequest(
|
||||||
|
model=global_config.model.lpmm_rdf_build,
|
||||||
|
request_type="lpmm.rdf_build"
|
||||||
|
)
|
||||||
|
def process_single_text(pg_hash, raw_data):
|
||||||
"""处理单个文本的函数,用于线程池"""
|
"""处理单个文本的函数,用于线程池"""
|
||||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||||
|
|
||||||
@@ -77,8 +81,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
|||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
|
|
||||||
entity_list, rdf_triple_list = info_extract_from_str(
|
entity_list, rdf_triple_list = info_extract_from_str(
|
||||||
llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
|
lpmm_entity_extract_llm,
|
||||||
llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
|
lpmm_rdf_build_llm,
|
||||||
raw_data,
|
raw_data,
|
||||||
)
|
)
|
||||||
if entity_list is None or rdf_triple_list is None:
|
if entity_list is None or rdf_triple_list is None:
|
||||||
@@ -113,7 +117,7 @@ def signal_handler(_signum, _frame):
|
|||||||
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
def main(): # sourcery skip: comprehension-to-generator, extract-method
|
||||||
# 设置信号处理器
|
# 设置信号处理器
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
ensure_dirs() # 确保目录存在
|
||||||
# 新增用户确认提示
|
# 新增用户确认提示
|
||||||
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
|
||||||
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
|
||||||
@@ -130,50 +134,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
|||||||
ensure_dirs() # 确保目录存在
|
ensure_dirs() # 确保目录存在
|
||||||
logger.info("--------进行信息提取--------\n")
|
logger.info("--------进行信息提取--------\n")
|
||||||
|
|
||||||
logger.info("创建LLM客户端")
|
# 加载原始数据
|
||||||
llm_client_list = {
|
logger.info("正在加载原始数据")
|
||||||
key: LLMClient(
|
all_sha256_list, all_raw_datas = load_raw_data()
|
||||||
global_config["llm_providers"][key]["base_url"],
|
|
||||||
global_config["llm_providers"][key]["api_key"],
|
|
||||||
)
|
|
||||||
for key in global_config["llm_providers"]
|
|
||||||
}
|
|
||||||
# 检查 openie 输出目录
|
|
||||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
|
||||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
|
||||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
|
||||||
|
|
||||||
# 确保 TEMP_DIR 目录存在
|
|
||||||
if not os.path.exists(TEMP_DIR):
|
|
||||||
os.makedirs(TEMP_DIR)
|
|
||||||
logger.info(f"已创建缓存目录: {TEMP_DIR}")
|
|
||||||
|
|
||||||
# 遍历IMPORTED_DATA_PATH下所有json文件
|
|
||||||
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
|
|
||||||
if not imported_files:
|
|
||||||
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
all_sha256_list = []
|
|
||||||
all_raw_datas = []
|
|
||||||
|
|
||||||
for imported_file in imported_files:
|
|
||||||
logger.info(f"正在处理文件: {imported_file}")
|
|
||||||
try:
|
|
||||||
sha256_list, raw_datas = load_raw_data(imported_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
|
|
||||||
continue
|
|
||||||
all_sha256_list.extend(sha256_list)
|
|
||||||
all_raw_datas.extend(raw_datas)
|
|
||||||
|
|
||||||
failed_sha256 = []
|
failed_sha256 = []
|
||||||
open_ie_doc = []
|
open_ie_doc = []
|
||||||
|
|
||||||
workers = global_config["info_extraction"]["workers"]
|
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
future_to_hash = {
|
future_to_hash = {
|
||||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +1,16 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys # 新增系统模块导入
|
import sys # 新增系统模块导入
|
||||||
import datetime # 新增导入
|
from src.chat.knowledge.utils.hash import get_sha256
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.lpmmconfig import global_config
|
|
||||||
|
|
||||||
logger = get_logger("lpmm")
|
logger = get_logger("lpmm")
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||||
# 新增:确保 RAW_DATA_PATH 存在
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||||
if not os.path.exists(RAW_DATA_PATH):
|
|
||||||
os.makedirs(RAW_DATA_PATH, exist_ok=True)
|
|
||||||
logger.info(f"已创建目录: {RAW_DATA_PATH}")
|
|
||||||
|
|
||||||
if global_config.get("persistence", {}).get("raw_data_path") is not None:
|
def _process_text_file(file_path):
|
||||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
|
|
||||||
else:
|
|
||||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
|
||||||
|
|
||||||
|
|
||||||
def check_and_create_dirs():
|
|
||||||
"""检查并创建必要的目录"""
|
|
||||||
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
|
|
||||||
|
|
||||||
for dir_path in required_dirs:
|
|
||||||
if not os.path.exists(dir_path):
|
|
||||||
os.makedirs(dir_path)
|
|
||||||
logger.info(f"已创建目录: {dir_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def process_text_file(file_path):
|
|
||||||
"""处理单个文本文件,返回段落列表"""
|
"""处理单个文本文件,返回段落列表"""
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
raw = f.read()
|
raw = f.read()
|
||||||
@@ -55,54 +31,45 @@ def process_text_file(file_path):
|
|||||||
return paragraphs
|
return paragraphs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def _process_multi_files() -> list:
|
||||||
# 新增用户确认提示
|
|
||||||
print("=== 数据预处理脚本 ===")
|
|
||||||
print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
|
|
||||||
print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
|
|
||||||
print("请确保原始数据已放置在正确的目录中。")
|
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
|
||||||
if confirm != "y":
|
|
||||||
logger.info("操作已取消")
|
|
||||||
sys.exit(1)
|
|
||||||
print("\n" + "=" * 40 + "\n")
|
|
||||||
|
|
||||||
# 检查并创建必要的目录
|
|
||||||
check_and_create_dirs()
|
|
||||||
|
|
||||||
# # 检查输出文件是否存在
|
|
||||||
# if os.path.exists(RAW_DATA_PATH):
|
|
||||||
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
|
|
||||||
# sys.exit(1)
|
|
||||||
|
|
||||||
# if os.path.exists(RAW_DATA_PATH):
|
|
||||||
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
|
|
||||||
# sys.exit(1)
|
|
||||||
|
|
||||||
# 获取所有原始文本文件
|
|
||||||
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||||
if not raw_files:
|
if not raw_files:
|
||||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 处理所有文件
|
# 处理所有文件
|
||||||
all_paragraphs = []
|
all_paragraphs = []
|
||||||
for file in raw_files:
|
for file in raw_files:
|
||||||
logger.info(f"正在处理文件: {file.name}")
|
logger.info(f"正在处理文件: {file.name}")
|
||||||
paragraphs = process_text_file(file)
|
paragraphs = _process_text_file(file)
|
||||||
all_paragraphs.extend(paragraphs)
|
all_paragraphs.extend(paragraphs)
|
||||||
|
return all_paragraphs
|
||||||
|
|
||||||
# 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json
|
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||||
now = datetime.datetime.now()
|
"""加载原始数据文件
|
||||||
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
|
|
||||||
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
|
||||||
|
|
||||||
logger.info(f"处理完成,结果已保存到: {output_path}")
|
读取原始数据文件,将原始数据加载到内存中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 可选,指定要读取的json文件绝对路径
|
||||||
|
|
||||||
if __name__ == "__main__":
|
Returns:
|
||||||
logger.info(f"原始数据路径: {RAW_DATA_PATH}")
|
- raw_data: 原始数据列表
|
||||||
logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
|
- sha256_list: 原始数据的SHA256集合
|
||||||
main()
|
"""
|
||||||
|
raw_data = _process_multi_files()
|
||||||
|
sha256_list = []
|
||||||
|
sha256_set = set()
|
||||||
|
for item in raw_data:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
logger.warning(f"数据类型错误:{item}")
|
||||||
|
continue
|
||||||
|
pg_hash = get_sha256(item)
|
||||||
|
if pg_hash in sha256_set:
|
||||||
|
logger.warning(f"重复数据:{item}")
|
||||||
|
continue
|
||||||
|
sha256_set.add(pg_hash)
|
||||||
|
sha256_list.append(pg_hash)
|
||||||
|
raw_data.append(item)
|
||||||
|
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||||
|
|
||||||
|
return sha256_list, raw_data
|
||||||
@@ -47,7 +47,7 @@ class MaiEmoji:
|
|||||||
self.embedding = []
|
self.embedding = []
|
||||||
self.hash = "" # 初始为空,在创建实例时会计算
|
self.hash = "" # 初始为空,在创建实例时会计算
|
||||||
self.description = ""
|
self.description = ""
|
||||||
self.emotion = []
|
self.emotion: List[str] = []
|
||||||
self.usage_count = 0
|
self.usage_count = 0
|
||||||
self.last_used_time = time.time()
|
self.last_used_time = time.time()
|
||||||
self.register_time = time.time()
|
self.register_time = time.time()
|
||||||
|
|||||||
@@ -243,6 +243,8 @@ class HeartFChatting:
|
|||||||
loop_start_time = time.time()
|
loop_start_time = time.time()
|
||||||
await self.relationship_builder.build_relation()
|
await self.relationship_builder.build_relation()
|
||||||
|
|
||||||
|
available_actions = {}
|
||||||
|
|
||||||
# 第一步:动作修改
|
# 第一步:动作修改
|
||||||
with Timer("动作修改", cycle_timers):
|
with Timer("动作修改", cycle_timers):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import pandas as pd
|
|||||||
# import tqdm
|
# import tqdm
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
from .llm_client import LLMClient
|
# from .llm_client import LLMClient
|
||||||
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
|
from .lpmmconfig import global_config
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -25,6 +25,9 @@ from rich.progress import (
|
|||||||
SpinnerColumn,
|
SpinnerColumn,
|
||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
from src.chat.utils.utils import get_embedding
|
||||||
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
@@ -86,21 +89,20 @@ class EmbeddingStoreItem:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingStore:
|
class EmbeddingStore:
|
||||||
def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
|
def __init__(self, namespace: str, dir_path: str):
|
||||||
self.namespace = namespace
|
self.namespace = namespace
|
||||||
self.llm_client = llm_client
|
|
||||||
self.dir = dir_path
|
self.dir = dir_path
|
||||||
self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
|
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||||
self.index_file_path = dir_path + "/" + namespace + ".index"
|
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||||
|
|
||||||
self.store = dict()
|
self.store = {}
|
||||||
|
|
||||||
self.faiss_index = None
|
self.faiss_index = None
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
return get_embedding(s)
|
||||||
|
|
||||||
def get_test_file_path(self):
|
def get_test_file_path(self):
|
||||||
return EMBEDDING_TEST_FILE
|
return EMBEDDING_TEST_FILE
|
||||||
@@ -293,20 +295,17 @@ class EmbeddingStore:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self, llm_client: LLMClient):
|
def __init__(self):
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
llm_client,
|
local_storage['pg_namespace'],
|
||||||
PG_NAMESPACE,
|
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
self.entities_embedding_store = EmbeddingStore(
|
self.entities_embedding_store = EmbeddingStore(
|
||||||
llm_client,
|
local_storage['pg_namespace'],
|
||||||
ENT_NAMESPACE,
|
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
self.relation_embedding_store = EmbeddingStore(
|
self.relation_embedding_store = EmbeddingStore(
|
||||||
llm_client,
|
local_storage['pg_namespace'],
|
||||||
REL_NAMESPACE,
|
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
)
|
)
|
||||||
self.stored_pg_hashes = set()
|
self.stored_pg_hashes = set()
|
||||||
|
|||||||
@@ -4,28 +4,35 @@ from typing import List, Union
|
|||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
from . import prompt_template
|
from . import prompt_template
|
||||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
from .knowledge_lib import INVALID_ENTITY
|
||||||
from .llm_client import LLMClient
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
from json_repair import repair_json
|
||||||
|
def _extract_json_from_text(text: str) -> dict:
|
||||||
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
|
try:
|
||||||
|
fixed_json = repair_json(text)
|
||||||
|
if isinstance(fixed_json, str):
|
||||||
|
parsed_json = json.loads(fixed_json)
|
||||||
|
else:
|
||||||
|
parsed_json = fixed_json
|
||||||
|
|
||||||
|
if isinstance(parsed_json, list) and parsed_json:
|
||||||
|
parsed_json = parsed_json[0]
|
||||||
|
|
||||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
if isinstance(parsed_json, dict):
|
||||||
|
return parsed_json
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
|
||||||
|
|
||||||
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
_, request_result = llm_client.send_chat_request(
|
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
||||||
global_config["entity_extract"]["llm"]["model"], entity_extract_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
|
||||||
if "[" in request_result:
|
|
||||||
request_result = request_result[request_result.index("[") :]
|
|
||||||
|
|
||||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
|
||||||
if "]" in request_result:
|
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
|
||||||
|
|
||||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
|
||||||
|
|
||||||
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
|
# 尝试load JSON数据
|
||||||
|
json.loads(entity_extract_result)
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
for entity in entity_extract_result
|
for entity in entity_extract_result
|
||||||
@@ -38,23 +45,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
|||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
||||||
|
|
||||||
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
|
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
|
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
||||||
|
|
||||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
|
||||||
if "[" in request_result:
|
|
||||||
request_result = request_result[request_result.index("[") :]
|
|
||||||
|
|
||||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
|
||||||
if "]" in request_result:
|
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
|
||||||
|
|
||||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
|
||||||
|
|
||||||
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
|
# 尝试load JSON数据
|
||||||
|
json.loads(entity_extract_result)
|
||||||
for triple in entity_extract_result:
|
for triple in entity_extract_result:
|
||||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||||
raise Exception("RDF提取结果格式错误")
|
raise Exception("RDF提取结果格式错误")
|
||||||
@@ -63,7 +63,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
|
|||||||
|
|
||||||
|
|
||||||
def info_extract_from_str(
|
def info_extract_from_str(
|
||||||
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
|
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||||
try_count = 0
|
try_count = 0
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank
|
|||||||
|
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||||
from .lpmmconfig import (
|
from .lpmmconfig import global_config
|
||||||
ENT_NAMESPACE,
|
from src.manager.local_store_manager import local_storage
|
||||||
PG_NAMESPACE,
|
|
||||||
RAG_ENT_CNT_NAMESPACE,
|
|
||||||
RAG_GRAPH_NAMESPACE,
|
|
||||||
RAG_PG_HASH_NAMESPACE,
|
|
||||||
global_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
KG_DIR = (
|
def _get_kg_dir():
|
||||||
os.path.join(ROOT_PATH, "data/rag")
|
"""
|
||||||
if global_config["persistence"]["rag_data_dir"] is None
|
安全地获取KG数据目录路径
|
||||||
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
|
"""
|
||||||
)
|
root_path = local_storage['root_path']
|
||||||
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
|
if root_path is None:
|
||||||
|
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||||
|
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
||||||
|
|
||||||
|
# 获取RAG数据目录
|
||||||
|
rag_data_dir = global_config["persistence"]["rag_data_dir"]
|
||||||
|
if rag_data_dir is None:
|
||||||
|
kg_dir = os.path.join(root_path, "data/rag")
|
||||||
|
else:
|
||||||
|
kg_dir = os.path.join(root_path, rag_data_dir)
|
||||||
|
|
||||||
|
return str(kg_dir).replace("\\", "/")
|
||||||
|
|
||||||
|
|
||||||
|
# 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage
|
||||||
|
def get_kg_dir_str():
|
||||||
|
"""获取KG目录字符串"""
|
||||||
|
return _get_kg_dir()
|
||||||
|
|
||||||
|
|
||||||
class KGManager:
|
class KGManager:
|
||||||
@@ -46,15 +59,15 @@ class KGManager:
|
|||||||
# 存储段落的hash值,用于去重
|
# 存储段落的hash值,用于去重
|
||||||
self.stored_paragraph_hashes = set()
|
self.stored_paragraph_hashes = set()
|
||||||
# 实体出现次数
|
# 实体出现次数
|
||||||
self.ent_appear_cnt = dict()
|
self.ent_appear_cnt = {}
|
||||||
# KG
|
# KG
|
||||||
self.graph = di_graph.DiGraph()
|
self.graph = di_graph.DiGraph()
|
||||||
|
|
||||||
# 持久化相关
|
# 持久化相关 - 使用延迟初始化的路径
|
||||||
self.dir_path = KG_DIR_STR
|
self.dir_path = get_kg_dir_str()
|
||||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
|
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
|
||||||
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
|
||||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
|
||||||
|
|
||||||
def save_to_file(self):
|
def save_to_file(self):
|
||||||
"""将KG数据保存到文件"""
|
"""将KG数据保存到文件"""
|
||||||
@@ -109,8 +122,8 @@ class KGManager:
|
|||||||
# 避免自连接
|
# 避免自连接
|
||||||
continue
|
continue
|
||||||
# 一个triple就是一条边(同时构建双向联系)
|
# 一个triple就是一条边(同时构建双向联系)
|
||||||
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||||
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
|
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
|
||||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||||
entity_set.add(hash_key1)
|
entity_set.add(hash_key1)
|
||||||
@@ -128,8 +141,8 @@ class KGManager:
|
|||||||
"""构建实体节点与文段节点之间的关系"""
|
"""构建实体节点与文段节点之间的关系"""
|
||||||
for idx in triple_list_data:
|
for idx in triple_list_data:
|
||||||
for triple in triple_list_data[idx]:
|
for triple in triple_list_data[idx]:
|
||||||
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||||
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
|
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
|
||||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -144,8 +157,8 @@ class KGManager:
|
|||||||
ent_hash_list = set()
|
ent_hash_list = set()
|
||||||
for triple_list in triple_list_data.values():
|
for triple_list in triple_list_data.values():
|
||||||
for triple in triple_list:
|
for triple in triple_list:
|
||||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
|
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
|
||||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
|
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
|
||||||
ent_hash_list = list(ent_hash_list)
|
ent_hash_list = list(ent_hash_list)
|
||||||
|
|
||||||
synonym_hash_set = set()
|
synonym_hash_set = set()
|
||||||
@@ -250,7 +263,7 @@ class KGManager:
|
|||||||
for src_tgt in node_to_node.keys():
|
for src_tgt in node_to_node.keys():
|
||||||
for node_hash in src_tgt:
|
for node_hash in src_tgt:
|
||||||
if node_hash not in existed_nodes:
|
if node_hash not in existed_nodes:
|
||||||
if node_hash.startswith(ENT_NAMESPACE):
|
if node_hash.startswith(local_storage['ent_namespace']):
|
||||||
# 新增实体节点
|
# 新增实体节点
|
||||||
node = embedding_manager.entities_embedding_store.store[node_hash]
|
node = embedding_manager.entities_embedding_store.store[node_hash]
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
@@ -259,7 +272,7 @@ class KGManager:
|
|||||||
node_item["type"] = "ent"
|
node_item["type"] = "ent"
|
||||||
node_item["create_time"] = now_time
|
node_item["create_time"] = now_time
|
||||||
self.graph.update_node(node_item)
|
self.graph.update_node(node_item)
|
||||||
elif node_hash.startswith(PG_NAMESPACE):
|
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||||
# 新增文段节点
|
# 新增文段节点
|
||||||
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
||||||
assert isinstance(node, EmbeddingStoreItem)
|
assert isinstance(node, EmbeddingStoreItem)
|
||||||
@@ -340,7 +353,7 @@ class KGManager:
|
|||||||
# 关系三元组
|
# 关系三元组
|
||||||
triple = relation[2:-2].split("', '")
|
triple = relation[2:-2].split("', '")
|
||||||
for ent in [(triple[0]), (triple[2])]:
|
for ent in [(triple[0]), (triple[2])]:
|
||||||
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
|
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
|
||||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||||
ent_sim_scores[ent_hash] = []
|
ent_sim_scores[ent_hash] = []
|
||||||
@@ -418,7 +431,7 @@ class KGManager:
|
|||||||
# 获取最终结果
|
# 获取最终结果
|
||||||
# 从搜索结果中提取文段节点的结果
|
# 从搜索结果中提取文段节点的结果
|
||||||
passage_node_res = [
|
passage_node_res = [
|
||||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
|
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
|
||||||
]
|
]
|
||||||
del ppr_res
|
del ppr_res
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
from src.chat.knowledge.lpmmconfig import global_config
|
||||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||||
from src.chat.knowledge.llm_client import LLMClient
|
from src.chat.knowledge.llm_client import LLMClient
|
||||||
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
||||||
@@ -6,10 +6,80 @@ from src.chat.knowledge.qa_manager import QAManager
|
|||||||
from src.chat.knowledge.kg_manager import KGManager
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
from src.chat.knowledge.global_logger import logger
|
from src.chat.knowledge.global_logger import logger
|
||||||
from src.config.config import global_config as bot_global_config
|
from src.config.config import global_config as bot_global_config
|
||||||
# try:
|
from src.manager.local_store_manager import local_storage
|
||||||
# import quick_algo
|
import os
|
||||||
# except ImportError:
|
|
||||||
# print("quick_algo not found, please install it first")
|
INVALID_ENTITY = [
|
||||||
|
"",
|
||||||
|
"你",
|
||||||
|
"他",
|
||||||
|
"她",
|
||||||
|
"它",
|
||||||
|
"我们",
|
||||||
|
"你们",
|
||||||
|
"他们",
|
||||||
|
"她们",
|
||||||
|
"它们",
|
||||||
|
]
|
||||||
|
PG_NAMESPACE = "paragraph"
|
||||||
|
ENT_NAMESPACE = "entity"
|
||||||
|
REL_NAMESPACE = "relation"
|
||||||
|
|
||||||
|
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||||
|
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||||
|
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||||
|
|
||||||
|
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||||
|
|
||||||
|
def _initialize_knowledge_local_storage():
    """
    Seed ``local_storage`` with the knowledge-base default settings.

    Every key in the defaults table whose current value in ``local_storage``
    is ``None`` is assigned its default; keys that already hold a value are
    left untouched. The defaults are kept in one dict and applied in a single
    loop so that no per-key ``if`` is needed.

    Path settings are logged at INFO level, all other settings at DEBUG.
    A summary line reports how many keys were written, or that nothing
    needed initializing.
    """
    # All settings that must exist, with their default values.
    default_configs = {
        # Path settings
        "root_path": ROOT_PATH,
        # Reuse the module constant (built with os.path.join) instead of
        # hand-concatenating "/data", so both always name the same path.
        "data_path": DATA_PATH,
        # Entity and namespace settings
        "lpmm_invalid_entity": INVALID_ENTITY,
        "pg_namespace": PG_NAMESPACE,
        "ent_namespace": ENT_NAMESPACE,
        "rel_namespace": REL_NAMESPACE,
        # RAG-related namespace settings
        "rag_graph_namespace": RAG_GRAPH_NAMESPACE,
        "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
        "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
    }

    # Keys whose initialization is logged at INFO; the rest use DEBUG.
    important_configs = {"root_path", "data_path"}

    initialized_count = 0
    for key, default_value in default_configs.items():
        # Only fill in settings that have no stored value yet.
        if local_storage[key] is not None:
            continue

        local_storage[key] = default_value

        if key in important_configs:
            logger.info(f"设置{key}: {default_value}")
        else:
            logger.debug(f"设置{key}: {default_value}")

        initialized_count += 1

    if initialized_count > 0:
        logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
    else:
        logger.debug("知识库本地存储配置已存在,跳过初始化")
|
||||||
|
|
||||||
|
# 初始化本地存储路径
|
||||||
|
_initialize_knowledge_local_storage()
|
||||||
|
|
||||||
# 检查LPMM知识库是否启用
|
# 检查LPMM知识库是否启用
|
||||||
if bot_global_config.lpmm_knowledge.enable:
|
if bot_global_config.lpmm_knowledge.enable:
|
||||||
@@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
embed_manager = EmbeddingManager()
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
try:
|
try:
|
||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
@@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable:
|
|||||||
qa_manager = QAManager(
|
qa_manager = QAManager(
|
||||||
embed_manager,
|
embed_manager,
|
||||||
kg_manager,
|
kg_manager,
|
||||||
llm_client_list[global_config["embedding"]["provider"]],
|
|
||||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
|
||||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记忆激活(用于记忆库)
|
# 记忆激活(用于记忆库)
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import glob
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||||
|
# from src.manager.local_store_manager import local_storage
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||||
@@ -107,7 +106,7 @@ class OpenIE:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def load() -> "OpenIE":
|
def load() -> "OpenIE":
|
||||||
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
||||||
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
|
openie_dir = os.path.join(DATA_PATH, "openie")
|
||||||
if not os.path.exists(openie_dir):
|
if not os.path.exists(openie_dir):
|
||||||
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
||||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||||
@@ -122,12 +121,6 @@ class OpenIE:
|
|||||||
openie_data = OpenIE._from_dict(data_list)
|
openie_data = OpenIE._from_dict(data_list)
|
||||||
return openie_data
|
return openie_data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save(openie_data: "OpenIE"):
|
|
||||||
"""保存OpenIE数据到文件"""
|
|
||||||
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
def extract_entity_dict(self):
|
def extract_entity_dict(self):
|
||||||
"""提取实体列表"""
|
"""提取实体列表"""
|
||||||
ner_output_dict = dict(
|
ner_output_dict = dict(
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ from .global_logger import logger
|
|||||||
|
|
||||||
# from . import prompt_template
|
# from . import prompt_template
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
from .llm_client import LLMClient
|
# from .llm_client import LLMClient
|
||||||
from .kg_manager import KGManager
|
from .kg_manager import KGManager
|
||||||
from .lpmmconfig import global_config
|
# from .lpmmconfig import global_config
|
||||||
from .utils.dyn_topk import dyn_select_top_k
|
from .utils.dyn_topk import dyn_select_top_k
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.chat.utils.utils import get_embedding
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||||
|
|
||||||
@@ -19,26 +21,25 @@ class QAManager:
|
|||||||
self,
|
self,
|
||||||
embed_manager: EmbeddingManager,
|
embed_manager: EmbeddingManager,
|
||||||
kg_manager: KGManager,
|
kg_manager: KGManager,
|
||||||
llm_client_embedding: LLMClient,
|
|
||||||
llm_client_filter: LLMClient,
|
|
||||||
llm_client_qa: LLMClient,
|
|
||||||
):
|
):
|
||||||
self.embed_manager = embed_manager
|
self.embed_manager = embed_manager
|
||||||
self.kg_manager = kg_manager
|
self.kg_manager = kg_manager
|
||||||
self.llm_client_list = {
|
# TODO: API-Adapter修改标记
|
||||||
"embedding": llm_client_embedding,
|
self.qa_model = LLMRequest(
|
||||||
"message_filter": llm_client_filter,
|
model=global_config.model.lpmm_qa,
|
||||||
"qa": llm_client_qa,
|
request_type="lpmm.qa"
|
||||||
}
|
)
|
||||||
|
|
||||||
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
question_embedding = get_embedding(question)
|
||||||
global_config["embedding"]["model"], question
|
if question_embedding is None:
|
||||||
)
|
logger.error("生成问题Embedding失败")
|
||||||
|
return None
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
@@ -46,14 +47,15 @@ class QAManager:
|
|||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||||
question_embedding,
|
question_embedding,
|
||||||
global_config["qa"]["params"]["relation_search_top_k"],
|
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||||
)
|
)
|
||||||
if relation_search_res is not None:
|
if relation_search_res is not None:
|
||||||
# 过滤阈值
|
# 过滤阈值
|
||||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||||
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||||
# 未找到相关关系
|
# 未找到相关关系
|
||||||
|
logger.debug("未找到相关关系,跳过关系检索")
|
||||||
relation_search_res = []
|
relation_search_res = []
|
||||||
|
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
@@ -71,7 +73,7 @@ class QAManager:
|
|||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||||
question_embedding,
|
question_embedding,
|
||||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||||
)
|
)
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|||||||
@@ -38,7 +38,9 @@ class HeartFCSender:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.storage = MessageStorage()
|
self.storage = MessageStorage()
|
||||||
|
|
||||||
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True):
|
async def send_message(
|
||||||
|
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
处理、发送并存储一条消息。
|
处理、发送并存储一条消息。
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,9 @@ class ActionPlanner:
|
|||||||
|
|
||||||
self.last_obs_time_mark = 0.0
|
self.last_obs_time_mark = 0.0
|
||||||
|
|
||||||
async def plan(self, mode: ChatMode = ChatMode.FOCUS) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
|
async def plan(
|
||||||
|
self, mode: ChatMode = ChatMode.FOCUS
|
||||||
|
) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
|
||||||
"""
|
"""
|
||||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -479,18 +479,18 @@ class DefaultReplyer:
|
|||||||
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]:
|
def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
构建 s4u 风格的分离对话 prompt
|
构建 s4u 风格的分离对话 prompt
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message_list_before_now: 历史消息列表
|
message_list_before_now: 历史消息列表
|
||||||
target_user_id: 目标用户ID(当前对话对象)
|
target_user_id: 目标用户ID(当前对话对象)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: (核心对话prompt, 背景对话prompt)
|
tuple: (核心对话prompt, 背景对话prompt)
|
||||||
"""
|
"""
|
||||||
core_dialogue_list = []
|
core_dialogue_list = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
|
|
||||||
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
# 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话
|
||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
@@ -503,11 +503,11 @@ class DefaultReplyer:
|
|||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
# 构建背景对话 prompt
|
# 构建背景对话 prompt
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size*0.6):]
|
latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :]
|
||||||
background_dialogue_prompt_str = build_readable_messages(
|
background_dialogue_prompt_str = build_readable_messages(
|
||||||
latest_25_msgs,
|
latest_25_msgs,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
@@ -516,12 +516,12 @@ class DefaultReplyer:
|
|||||||
show_pic=False,
|
show_pic=False,
|
||||||
)
|
)
|
||||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||||
|
|
||||||
# 构建核心对话 prompt
|
# 构建核心对话 prompt
|
||||||
core_dialogue_prompt = ""
|
core_dialogue_prompt = ""
|
||||||
if core_dialogue_list:
|
if core_dialogue_list:
|
||||||
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size*2):] # 限制消息数量
|
core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量
|
||||||
|
|
||||||
core_dialogue_prompt_str = build_readable_messages(
|
core_dialogue_prompt_str = build_readable_messages(
|
||||||
core_dialogue_list,
|
core_dialogue_list,
|
||||||
replace_bot_name=True,
|
replace_bot_name=True,
|
||||||
@@ -532,7 +532,7 @@ class DefaultReplyer:
|
|||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
core_dialogue_prompt = core_dialogue_prompt_str
|
core_dialogue_prompt = core_dialogue_prompt_str
|
||||||
|
|
||||||
return core_dialogue_prompt, background_dialogue_prompt
|
return core_dialogue_prompt, background_dialogue_prompt
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
@@ -578,14 +578,13 @@ class DefaultReplyer:
|
|||||||
action_description = action_info.description
|
action_description = action_info.description
|
||||||
action_descriptions += f"- {action_name}: {action_description}\n"
|
action_descriptions += f"- {action_name}: {action_description}\n"
|
||||||
action_descriptions += "\n"
|
action_descriptions += "\n"
|
||||||
|
|
||||||
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=global_config.chat.max_context_size * 2,
|
limit=global_config.chat.max_context_size * 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
@@ -712,8 +711,6 @@ class DefaultReplyer:
|
|||||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 根据配置选择使用哪种 prompt 构建模式
|
# 根据配置选择使用哪种 prompt 构建模式
|
||||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
if global_config.chat.use_s4u_prompt_mode and person_id:
|
||||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||||
@@ -724,16 +721,15 @@ class DefaultReplyer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||||
target_user_id = ""
|
target_user_id = ""
|
||||||
|
|
||||||
|
|
||||||
# 构建分离的对话 prompt
|
# 构建分离的对话 prompt
|
||||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||||
message_list_before_now_long, target_user_id
|
message_list_before_now_long, target_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 使用 s4u 风格的模板
|
# 使用 s4u 风格的模板
|
||||||
template_name = "s4u_style_prompt"
|
template_name = "s4u_style_prompt"
|
||||||
|
|
||||||
return await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
expression_habits_block=expression_habits_block,
|
expression_habits_block=expression_habits_block,
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ class PersonalityConfig(ConfigBase):
|
|||||||
|
|
||||||
personality_side: str
|
personality_side: str
|
||||||
"""人格侧写"""
|
"""人格侧写"""
|
||||||
|
|
||||||
identity: str = ""
|
identity: str = ""
|
||||||
"""身份特征"""
|
"""身份特征"""
|
||||||
|
|
||||||
@@ -106,7 +106,6 @@ class ChatConfig(ConfigBase):
|
|||||||
focus_value: float = 1.0
|
focus_value: float = 1.0
|
||||||
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
"""麦麦的专注思考能力,越低越容易专注,消耗token也越多"""
|
||||||
|
|
||||||
|
|
||||||
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float:
|
||||||
"""
|
"""
|
||||||
根据当前时间和聊天流获取对应的 talk_frequency
|
根据当前时间和聊天流获取对应的 talk_frequency
|
||||||
@@ -246,6 +245,7 @@ class ChatConfig(ConfigBase):
|
|||||||
except (ValueError, IndexError):
|
except (ValueError, IndexError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MessageReceiveConfig(ConfigBase):
|
class MessageReceiveConfig(ConfigBase):
|
||||||
"""消息接收配置类"""
|
"""消息接收配置类"""
|
||||||
@@ -274,8 +274,6 @@ class NormalChatConfig(ConfigBase):
|
|||||||
"""@bot 必然回复"""
|
"""@bot 必然回复"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExpressionConfig(ConfigBase):
|
class ExpressionConfig(ConfigBase):
|
||||||
"""表达配置类"""
|
"""表达配置类"""
|
||||||
@@ -627,3 +625,12 @@ class ModelConfig(ConfigBase):
|
|||||||
|
|
||||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
"""嵌入模型配置"""
|
"""嵌入模型配置"""
|
||||||
|
|
||||||
|
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM实体提取模型配置"""
|
||||||
|
|
||||||
|
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM RDF构建模型配置"""
|
||||||
|
|
||||||
|
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM问答模型配置"""
|
||||||
|
|||||||
@@ -41,11 +41,11 @@ class Individuality:
|
|||||||
personality_side: 人格侧面描述
|
personality_side: 人格侧面描述
|
||||||
identity: 身份细节描述
|
identity: 身份细节描述
|
||||||
"""
|
"""
|
||||||
bot_nickname=global_config.bot.nickname
|
bot_nickname = global_config.bot.nickname
|
||||||
personality_core=global_config.personality.personality_core
|
personality_core = global_config.personality.personality_core
|
||||||
personality_side=global_config.personality.personality_side
|
personality_side = global_config.personality.personality_side
|
||||||
identity=global_config.personality.identity
|
identity = global_config.personality.identity
|
||||||
|
|
||||||
logger.info("正在初始化个体特征")
|
logger.info("正在初始化个体特征")
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||||
@@ -146,11 +146,10 @@ class Individuality:
|
|||||||
else:
|
else:
|
||||||
logger.error("人设构建失败")
|
logger.error("人设构建失败")
|
||||||
|
|
||||||
|
|
||||||
async def get_personality_block(self) -> str:
|
async def get_personality_block(self) -> str:
|
||||||
person_info_manager = get_person_info_manager()
|
person_info_manager = get_person_info_manager()
|
||||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||||
|
|
||||||
bot_name = global_config.bot.nickname
|
bot_name = global_config.bot.nickname
|
||||||
if global_config.bot.alias_names:
|
if global_config.bot.alias_names:
|
||||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||||
@@ -175,9 +174,8 @@ class Individuality:
|
|||||||
identity = short_impression[1]
|
identity = short_impression[1]
|
||||||
prompt_personality = f"{personality},{identity}"
|
prompt_personality = f"{personality},{identity}"
|
||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
return identity_block
|
|
||||||
|
|
||||||
|
return identity_block
|
||||||
|
|
||||||
def _get_config_hash(
|
def _get_config_hash(
|
||||||
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
self, bot_nickname: str, personality_core: str, personality_side: str, identity: list
|
||||||
@@ -273,7 +271,6 @@ class Individuality:
|
|||||||
except IOError as e:
|
except IOError as e:
|
||||||
logger.error(f"保存meta_info文件失败: {e}")
|
logger.error(f"保存meta_info文件失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
async def _create_personality(self, personality_core: str, personality_side: str) -> str:
|
||||||
# sourcery skip: merge-list-append, move-assign
|
# sourcery skip: merge-list-append, move-assign
|
||||||
"""使用LLM创建压缩版本的impression
|
"""使用LLM创建压缩版本的impression
|
||||||
|
|||||||
@@ -40,7 +40,15 @@ class Personality:
|
|||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize(cls, bot_nickname: str, personality_core: str, personality_side: str, identity: List[str] = None, compress_personality: bool = True, compress_identity: bool = True) -> "Personality":
|
def initialize(
|
||||||
|
cls,
|
||||||
|
bot_nickname: str,
|
||||||
|
personality_core: str,
|
||||||
|
personality_side: str,
|
||||||
|
identity: List[str] = None,
|
||||||
|
compress_personality: bool = True,
|
||||||
|
compress_identity: bool = True,
|
||||||
|
) -> "Personality":
|
||||||
"""初始化人格特质
|
"""初始化人格特质
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
|
|
||||||
from src.plugin_system.apis import send_api
|
from src.plugin_system.apis import send_api
|
||||||
|
|
||||||
|
|
||||||
async def send_loading(chat_id: str, content: str):
|
async def send_loading(chat_id: str, content: str):
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="loading",
|
message_type="loading",
|
||||||
@@ -9,7 +10,8 @@ async def send_loading(chat_id: str, content: str):
|
|||||||
storage_message=False,
|
storage_message=False,
|
||||||
show_log=True,
|
show_log=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def send_unloading(chat_id: str):
|
async def send_unloading(chat_id: str):
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="loading",
|
message_type="loading",
|
||||||
@@ -18,4 +20,3 @@ async def send_unloading(chat_id: str):
|
|||||||
storage_message=False,
|
storage_message=False,
|
||||||
show_log=True,
|
show_log=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ class ChatMood:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.last_change_time = 0
|
self.last_change_time = 0
|
||||||
|
|
||||||
# 发送初始情绪状态到ws端
|
# 发送初始情绪状态到ws端
|
||||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||||
|
|
||||||
@@ -231,10 +231,10 @@ class ChatMood:
|
|||||||
if numerical_mood_response:
|
if numerical_mood_response:
|
||||||
_old_mood_values = self.mood_values.copy()
|
_old_mood_values = self.mood_values.copy()
|
||||||
self.mood_values = numerical_mood_response
|
self.mood_values = numerical_mood_response
|
||||||
|
|
||||||
# 发送情绪更新到ws端
|
# 发送情绪更新到ws端
|
||||||
await self.send_emotion_update(self.mood_values)
|
await self.send_emotion_update(self.mood_values)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
||||||
|
|
||||||
self.last_change_time = message_time
|
self.last_change_time = message_time
|
||||||
@@ -308,10 +308,10 @@ class ChatMood:
|
|||||||
if numerical_mood_response:
|
if numerical_mood_response:
|
||||||
_old_mood_values = self.mood_values.copy()
|
_old_mood_values = self.mood_values.copy()
|
||||||
self.mood_values = numerical_mood_response
|
self.mood_values = numerical_mood_response
|
||||||
|
|
||||||
# 发送情绪更新到ws端
|
# 发送情绪更新到ws端
|
||||||
await self.send_emotion_update(self.mood_values)
|
await self.send_emotion_update(self.mood_values)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
||||||
|
|
||||||
self.regression_count += 1
|
self.regression_count += 1
|
||||||
@@ -322,9 +322,9 @@ class ChatMood:
|
|||||||
"joy": mood_values.get("joy", 5),
|
"joy": mood_values.get("joy", 5),
|
||||||
"anger": mood_values.get("anger", 1),
|
"anger": mood_values.get("anger", 1),
|
||||||
"sorrow": mood_values.get("sorrow", 1),
|
"sorrow": mood_values.get("sorrow", 1),
|
||||||
"fear": mood_values.get("fear", 1)
|
"fear": mood_values.get("fear", 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="emotion",
|
message_type="emotion",
|
||||||
content=emotion_data,
|
content=emotion_data,
|
||||||
@@ -332,7 +332,7 @@ class ChatMood:
|
|||||||
storage_message=False,
|
storage_message=False,
|
||||||
show_log=True,
|
show_log=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
||||||
|
|
||||||
|
|
||||||
@@ -345,27 +345,27 @@ class MoodRegressionTask(AsyncTask):
|
|||||||
async def run(self):
|
async def run(self):
|
||||||
self.run_count += 1
|
self.run_count += 1
|
||||||
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
||||||
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
regression_executed = 0
|
regression_executed = 0
|
||||||
|
|
||||||
for mood in self.mood_manager.mood_list:
|
for mood in self.mood_manager.mood_list:
|
||||||
chat_info = f"chat {mood.chat_id}"
|
chat_info = f"chat {mood.chat_id}"
|
||||||
|
|
||||||
if mood.last_change_time == 0:
|
if mood.last_change_time == 0:
|
||||||
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
time_since_last_change = now - mood.last_change_time
|
time_since_last_change = now - mood.last_change_time
|
||||||
|
|
||||||
# 检查是否有极端情绪需要快速回归
|
# 检查是否有极端情绪需要快速回归
|
||||||
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
||||||
has_extreme_emotion = len(high_emotions) > 0
|
has_extreme_emotion = len(high_emotions) > 0
|
||||||
|
|
||||||
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
||||||
should_regress = False
|
should_regress = False
|
||||||
regress_reason = ""
|
regress_reason = ""
|
||||||
|
|
||||||
if time_since_last_change > 120:
|
if time_since_last_change > 120:
|
||||||
should_regress = True
|
should_regress = True
|
||||||
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
||||||
@@ -373,24 +373,28 @@ class MoodRegressionTask(AsyncTask):
|
|||||||
should_regress = True
|
should_regress = True
|
||||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||||
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
||||||
|
|
||||||
if should_regress:
|
if should_regress:
|
||||||
if mood.regression_count >= 3:
|
if mood.regression_count >= 3:
|
||||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)")
|
logger.info(
|
||||||
|
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||||
|
)
|
||||||
await mood.regress_mood()
|
await mood.regress_mood()
|
||||||
regression_executed += 1
|
regression_executed += 1
|
||||||
else:
|
else:
|
||||||
if has_extreme_emotion:
|
if has_extreme_emotion:
|
||||||
remaining_time = 5 - time_since_last_change
|
remaining_time = 5 - time_since_last_change
|
||||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||||
logger.debug(f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒")
|
logger.debug(
|
||||||
|
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
remaining_time = 120 - time_since_last_change
|
remaining_time = 120 - time_since_last_change
|
||||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||||
|
|
||||||
if regression_executed > 0:
|
if regression_executed > 0:
|
||||||
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
||||||
else:
|
else:
|
||||||
@@ -409,11 +413,11 @@ class MoodManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("启动情绪管理任务...")
|
logger.info("启动情绪管理任务...")
|
||||||
|
|
||||||
# 启动情绪回归任务
|
# 启动情绪回归任务
|
||||||
regression_task = MoodRegressionTask(self)
|
regression_task = MoodRegressionTask(self)
|
||||||
await async_task_manager.add_task(regression_task)
|
await async_task_manager.add_task(regression_task)
|
||||||
|
|
||||||
self.task_started = True
|
self.task_started = True
|
||||||
logger.info("情绪管理任务已启动(情绪回归)")
|
logger.info("情绪管理任务已启动(情绪回归)")
|
||||||
|
|
||||||
@@ -435,7 +439,7 @@ class MoodManager:
|
|||||||
# 发送重置后的情绪状态到ws端
|
# 发送重置后的情绪状态到ws端
|
||||||
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果没有找到现有的mood,创建新的
|
# 如果没有找到现有的mood,创建新的
|
||||||
new_mood = ChatMood(chat_id)
|
new_mood = ChatMood(chat_id)
|
||||||
self.mood_list.append(new_mood)
|
self.mood_list.append(new_mood)
|
||||||
|
|||||||
@@ -107,7 +107,6 @@ class S4UStreamGenerator:
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
|
|
||||||
buffer = ""
|
buffer = ""
|
||||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||||
punctuation_buffer = ""
|
punctuation_buffer = ""
|
||||||
|
|||||||
@@ -43,23 +43,24 @@ logger = get_logger("watching")
|
|||||||
|
|
||||||
class WatchingState(Enum):
|
class WatchingState(Enum):
|
||||||
"""视线状态枚举"""
|
"""视线状态枚举"""
|
||||||
|
|
||||||
WANDERING = "wandering" # 随意看
|
WANDERING = "wandering" # 随意看
|
||||||
DANMU = "danmu" # 看弹幕
|
DANMU = "danmu" # 看弹幕
|
||||||
LENS = "lens" # 看镜头
|
LENS = "lens" # 看镜头
|
||||||
|
|
||||||
|
|
||||||
class ChatWatching:
|
class ChatWatching:
|
||||||
def __init__(self, chat_id: str):
|
def __init__(self, chat_id: str):
|
||||||
self.chat_id: str = chat_id
|
self.chat_id: str = chat_id
|
||||||
self.current_state: WatchingState = WatchingState.LENS # 默认看镜头
|
self.current_state: WatchingState = WatchingState.LENS # 默认看镜头
|
||||||
self.last_sent_state: Optional[WatchingState] = None # 上次发送的状态
|
self.last_sent_state: Optional[WatchingState] = None # 上次发送的状态
|
||||||
self.state_needs_update: bool = True # 是否需要更新状态
|
self.state_needs_update: bool = True # 是否需要更新状态
|
||||||
|
|
||||||
# 状态切换相关
|
# 状态切换相关
|
||||||
self.is_replying: bool = False # 是否正在生成回复
|
self.is_replying: bool = False # 是否正在生成回复
|
||||||
self.reply_finished_time: Optional[float] = None # 回复完成时间
|
self.reply_finished_time: Optional[float] = None # 回复完成时间
|
||||||
self.danmu_viewing_duration: float = 1.0 # 看弹幕持续时间(秒)
|
self.danmu_viewing_duration: float = 1.0 # 看弹幕持续时间(秒)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 视线管理器初始化,默认状态: {self.current_state.value}")
|
logger.info(f"[{self.chat_id}] 视线管理器初始化,默认状态: {self.current_state.value}")
|
||||||
|
|
||||||
async def _change_state(self, new_state: WatchingState, reason: str = ""):
|
async def _change_state(self, new_state: WatchingState, reason: str = ""):
|
||||||
@@ -69,7 +70,7 @@ class ChatWatching:
|
|||||||
self.current_state = new_state
|
self.current_state = new_state
|
||||||
self.state_needs_update = True
|
self.state_needs_update = True
|
||||||
logger.info(f"[{self.chat_id}] 视线状态切换: {old_state.value} → {new_state.value} ({reason})")
|
logger.info(f"[{self.chat_id}] 视线状态切换: {old_state.value} → {new_state.value} ({reason})")
|
||||||
|
|
||||||
# 立即发送视线状态更新
|
# 立即发送视线状态更新
|
||||||
await self._send_watching_update()
|
await self._send_watching_update()
|
||||||
else:
|
else:
|
||||||
@@ -86,7 +87,7 @@ class ChatWatching:
|
|||||||
"""开始生成回复时调用"""
|
"""开始生成回复时调用"""
|
||||||
self.is_replying = True
|
self.is_replying = True
|
||||||
self.reply_finished_time = None
|
self.reply_finished_time = None
|
||||||
|
|
||||||
if look_at_lens:
|
if look_at_lens:
|
||||||
await self._change_state(WatchingState.LENS, "开始生成回复-看镜头")
|
await self._change_state(WatchingState.LENS, "开始生成回复-看镜头")
|
||||||
else:
|
else:
|
||||||
@@ -96,35 +97,29 @@ class ChatWatching:
|
|||||||
"""生成回复完毕时调用"""
|
"""生成回复完毕时调用"""
|
||||||
self.is_replying = False
|
self.is_replying = False
|
||||||
self.reply_finished_time = time.time()
|
self.reply_finished_time = time.time()
|
||||||
|
|
||||||
# 先看弹幕1秒
|
# 先看弹幕1秒
|
||||||
await self._change_state(WatchingState.DANMU, "回复完毕-看弹幕")
|
await self._change_state(WatchingState.DANMU, "回复完毕-看弹幕")
|
||||||
logger.info(f"[{self.chat_id}] 回复完毕,将看弹幕{self.danmu_viewing_duration}秒后转为看镜头")
|
logger.info(f"[{self.chat_id}] 回复完毕,将看弹幕{self.danmu_viewing_duration}秒后转为看镜头")
|
||||||
|
|
||||||
# 设置定时器,1秒后自动切换到看镜头
|
# 设置定时器,1秒后自动切换到看镜头
|
||||||
asyncio.create_task(self._auto_switch_to_lens())
|
asyncio.create_task(self._auto_switch_to_lens())
|
||||||
|
|
||||||
async def _auto_switch_to_lens(self):
|
async def _auto_switch_to_lens(self):
|
||||||
"""自动切换到看镜头(延迟执行)"""
|
"""自动切换到看镜头(延迟执行)"""
|
||||||
await asyncio.sleep(self.danmu_viewing_duration)
|
await asyncio.sleep(self.danmu_viewing_duration)
|
||||||
|
|
||||||
# 检查是否仍需要切换(可能状态已经被其他事件改变)
|
# 检查是否仍需要切换(可能状态已经被其他事件改变)
|
||||||
if (self.reply_finished_time is not None and
|
if self.reply_finished_time is not None and self.current_state == WatchingState.DANMU and not self.is_replying:
|
||||||
self.current_state == WatchingState.DANMU and
|
|
||||||
not self.is_replying):
|
|
||||||
|
|
||||||
await self._change_state(WatchingState.LENS, "看弹幕时间结束")
|
await self._change_state(WatchingState.LENS, "看弹幕时间结束")
|
||||||
self.reply_finished_time = None # 重置完成时间
|
self.reply_finished_time = None # 重置完成时间
|
||||||
|
|
||||||
async def _send_watching_update(self):
|
async def _send_watching_update(self):
|
||||||
"""立即发送视线状态更新"""
|
"""立即发送视线状态更新"""
|
||||||
await send_api.custom_to_stream(
|
await send_api.custom_to_stream(
|
||||||
message_type="watching",
|
message_type="watching", content=self.current_state.value, stream_id=self.chat_id, storage_message=False
|
||||||
content=self.current_state.value,
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
storage_message=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}")
|
logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}")
|
||||||
self.last_sent_state = self.current_state
|
self.last_sent_state = self.current_state
|
||||||
self.state_needs_update = False
|
self.state_needs_update = False
|
||||||
@@ -139,11 +134,10 @@ class ChatWatching:
|
|||||||
"current_state": self.current_state.value,
|
"current_state": self.current_state.value,
|
||||||
"is_replying": self.is_replying,
|
"is_replying": self.is_replying,
|
||||||
"reply_finished_time": self.reply_finished_time,
|
"reply_finished_time": self.reply_finished_time,
|
||||||
"state_needs_update": self.state_needs_update
|
"state_needs_update": self.state_needs_update,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WatchingManager:
|
class WatchingManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.watching_list: list[ChatWatching] = []
|
self.watching_list: list[ChatWatching] = []
|
||||||
@@ -156,7 +150,7 @@ class WatchingManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("启动视线管理系统...")
|
logger.info("启动视线管理系统...")
|
||||||
|
|
||||||
self.task_started = True
|
self.task_started = True
|
||||||
logger.info("视线管理系统已启动(状态变化时立即发送)")
|
logger.info("视线管理系统已启动(状态变化时立即发送)")
|
||||||
|
|
||||||
@@ -169,10 +163,10 @@ class WatchingManager:
|
|||||||
new_watching = ChatWatching(chat_id)
|
new_watching = ChatWatching(chat_id)
|
||||||
self.watching_list.append(new_watching)
|
self.watching_list.append(new_watching)
|
||||||
logger.info(f"为chat {chat_id}创建新的视线管理器")
|
logger.info(f"为chat {chat_id}创建新的视线管理器")
|
||||||
|
|
||||||
# 发送初始状态
|
# 发送初始状态
|
||||||
asyncio.create_task(new_watching._send_watching_update())
|
asyncio.create_task(new_watching._send_watching_update())
|
||||||
|
|
||||||
return new_watching
|
return new_watching
|
||||||
|
|
||||||
def reset_watching_by_chat_id(self, chat_id: str):
|
def reset_watching_by_chat_id(self, chat_id: str):
|
||||||
@@ -185,27 +179,24 @@ class WatchingManager:
|
|||||||
watching.is_replying = False
|
watching.is_replying = False
|
||||||
watching.reply_finished_time = None
|
watching.reply_finished_time = None
|
||||||
logger.info(f"[{chat_id}] 视线状态已重置为默认状态")
|
logger.info(f"[{chat_id}] 视线状态已重置为默认状态")
|
||||||
|
|
||||||
# 发送重置后的状态
|
# 发送重置后的状态
|
||||||
asyncio.create_task(watching._send_watching_update())
|
asyncio.create_task(watching._send_watching_update())
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果没有找到现有的watching,创建新的
|
# 如果没有找到现有的watching,创建新的
|
||||||
new_watching = ChatWatching(chat_id)
|
new_watching = ChatWatching(chat_id)
|
||||||
self.watching_list.append(new_watching)
|
self.watching_list.append(new_watching)
|
||||||
logger.info(f"为chat {chat_id}创建并重置视线管理器")
|
logger.info(f"为chat {chat_id}创建并重置视线管理器")
|
||||||
|
|
||||||
# 发送初始状态
|
# 发送初始状态
|
||||||
asyncio.create_task(new_watching._send_watching_update())
|
asyncio.create_task(new_watching._send_watching_update())
|
||||||
|
|
||||||
def get_all_watching_info(self) -> dict:
|
def get_all_watching_info(self) -> dict:
|
||||||
"""获取所有聊天的视线状态信息(用于调试)"""
|
"""获取所有聊天的视线状态信息(用于调试)"""
|
||||||
return {
|
return {watching.chat_id: watching.get_state_info() for watching in self.watching_list}
|
||||||
watching.chat_id: watching.get_state_info()
|
|
||||||
for watching in self.watching_list
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局视线管理器实例
|
# 全局视线管理器实例
|
||||||
watching_manager = WatchingManager()
|
watching_manager = WatchingManager()
|
||||||
"""全局视线管理器"""
|
"""全局视线管理器"""
|
||||||
|
|||||||
@@ -46,10 +46,10 @@ def init_prompt():
|
|||||||
class ChatMood:
|
class ChatMood:
|
||||||
def __init__(self, chat_id: str):
|
def __init__(self, chat_id: str):
|
||||||
self.chat_id: str = chat_id
|
self.chat_id: str = chat_id
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
self.chat_stream = chat_manager.get_stream(self.chat_id)
|
||||||
|
|
||||||
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]"
|
||||||
|
|
||||||
self.mood_state: str = "感觉很平静"
|
self.mood_state: str = "感觉很平静"
|
||||||
@@ -92,7 +92,7 @@ class ChatMood:
|
|||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp_start=self.last_change_time,
|
timestamp_start=self.last_change_time,
|
||||||
timestamp_end=message_time,
|
timestamp_end=message_time,
|
||||||
limit=int(global_config.chat.max_context_size/3),
|
limit=int(global_config.chat.max_context_size / 3),
|
||||||
limit_mode="last",
|
limit_mode="last",
|
||||||
)
|
)
|
||||||
chat_talking_prompt = build_readable_messages(
|
chat_talking_prompt = build_readable_messages(
|
||||||
@@ -121,14 +121,12 @@ class ChatMood:
|
|||||||
mood_state=self.mood_state,
|
mood_state=self.mood_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||||
logger.info(f"{self.log_prefix} response: {response}")
|
logger.info(f"{self.log_prefix} response: {response}")
|
||||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
logger.info(f"{self.log_prefix} 情绪状态更新为: {response}")
|
||||||
|
|
||||||
self.mood_state = response
|
self.mood_state = response
|
||||||
@@ -170,15 +168,14 @@ class ChatMood:
|
|||||||
mood_state=self.mood_state,
|
mood_state=self.mood_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
if global_config.debug.show_prompt:
|
if global_config.debug.show_prompt:
|
||||||
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
logger.info(f"{self.log_prefix} prompt: {prompt}")
|
||||||
logger.info(f"{self.log_prefix} response: {response}")
|
logger.info(f"{self.log_prefix} response: {response}")
|
||||||
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}")
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 情绪状态回归为: {response}")
|
logger.info(f"{self.log_prefix} 情绪状态回归为: {response}")
|
||||||
|
|
||||||
self.mood_state = response
|
self.mood_state = response
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,12 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[ChatStream]: 聊天流列表
|
List[ChatStream]: 聊天流列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
for _, stream in get_chat_manager().streams.items():
|
for _, stream in get_chat_manager().streams.items():
|
||||||
@@ -60,6 +65,8 @@ class ChatManager:
|
|||||||
Returns:
|
Returns:
|
||||||
List[ChatStream]: 群聊聊天流列表
|
List[ChatStream]: 群聊聊天流列表
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
for _, stream in get_chat_manager().streams.items():
|
for _, stream in get_chat_manager().streams.items():
|
||||||
@@ -79,7 +86,12 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[ChatStream]: 私聊聊天流列表
|
List[ChatStream]: 私聊聊天流列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
for _, stream in get_chat_manager().streams.items():
|
for _, stream in get_chat_manager().streams.items():
|
||||||
@@ -102,7 +114,17 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果 group_id 为空字符串
|
||||||
|
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(group_id, str):
|
||||||
|
raise TypeError("group_id 必须是字符串类型")
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
if not group_id:
|
||||||
|
raise ValueError("group_id 不能为空")
|
||||||
try:
|
try:
|
||||||
for _, stream in get_chat_manager().streams.items():
|
for _, stream in get_chat_manager().streams.items():
|
||||||
if (
|
if (
|
||||||
@@ -129,7 +151,17 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果 user_id 为空字符串
|
||||||
|
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(user_id, str):
|
||||||
|
raise TypeError("user_id 必须是字符串类型")
|
||||||
|
if not isinstance(platform, (str, SpecialTypes)):
|
||||||
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
|
if not user_id:
|
||||||
|
raise ValueError("user_id 不能为空")
|
||||||
try:
|
try:
|
||||||
for _, stream in get_chat_manager().streams.items():
|
for _, stream in get_chat_manager().streams.items():
|
||||||
if (
|
if (
|
||||||
@@ -153,9 +185,15 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 聊天类型 ("group", "private", "unknown")
|
str: 聊天类型 ("group", "private", "unknown")
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||||
|
ValueError: 如果 chat_stream 为空
|
||||||
"""
|
"""
|
||||||
|
if not isinstance(chat_stream, ChatStream):
|
||||||
|
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
raise ValueError("chat_stream cannot be None")
|
raise ValueError("chat_stream 不能为 None")
|
||||||
|
|
||||||
if hasattr(chat_stream, "group_info"):
|
if hasattr(chat_stream, "group_info"):
|
||||||
return "group" if chat_stream.group_info else "private"
|
return "group" if chat_stream.group_info else "private"
|
||||||
@@ -170,9 +208,15 @@ class ChatManager:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: 聊天流信息字典
|
Dict[str, Any]: 聊天流信息字典
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||||
|
ValueError: 如果 chat_stream 为空
|
||||||
"""
|
"""
|
||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
return {}
|
raise ValueError("chat_stream 不能为 None")
|
||||||
|
if not isinstance(chat_stream, ChatStream):
|
||||||
|
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
info: Dict[str, Any] = {
|
info: Dict[str, Any] = {
|
||||||
|
|||||||
@@ -8,6 +8,8 @@
|
|||||||
count = emoji_api.get_count()
|
count = emoji_api.get_count()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
from typing import Optional, Tuple, List
|
from typing import Optional, Tuple, List
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
@@ -29,7 +31,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果描述为空字符串
|
||||||
|
TypeError: 如果描述不是字符串类型
|
||||||
"""
|
"""
|
||||||
|
if not description:
|
||||||
|
raise ValueError("描述不能为空")
|
||||||
|
if not isinstance(description, str):
|
||||||
|
raise TypeError("描述必须是字符串类型")
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||||
|
|
||||||
@@ -55,7 +65,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
|
||||||
"""随机获取指定数量的表情包
|
"""随机获取指定数量的表情包
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -63,8 +73,17 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
|
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: 如果count不是整数类型
|
||||||
|
ValueError: 如果count为负数
|
||||||
"""
|
"""
|
||||||
if count <= 0:
|
if not isinstance(count, int):
|
||||||
|
raise TypeError("count 必须是整数类型")
|
||||||
|
if count < 0:
|
||||||
|
raise ValueError("count 不能为负数")
|
||||||
|
if count == 0:
|
||||||
|
logger.warning("[EmojiAPI] count 为0,返回空列表")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -90,8 +109,6 @@ async def get_random(count: int = 1) -> Optional[List[Tuple[str, str, str]]]:
|
|||||||
count = len(valid_emojis)
|
count = len(valid_emojis)
|
||||||
|
|
||||||
# 随机选择
|
# 随机选择
|
||||||
import random
|
|
||||||
|
|
||||||
selected_emojis = random.sample(valid_emojis, count)
|
selected_emojis = random.sample(valid_emojis, count)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@@ -128,7 +145,15 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果情感标签为空字符串
|
||||||
|
TypeError: 如果情感标签不是字符串类型
|
||||||
"""
|
"""
|
||||||
|
if not emotion:
|
||||||
|
raise ValueError("情感标签不能为空")
|
||||||
|
if not isinstance(emotion, str):
|
||||||
|
raise TypeError("情感标签必须是字符串类型")
|
||||||
try:
|
try:
|
||||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||||
|
|
||||||
@@ -146,8 +171,6 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 随机选择匹配的表情包
|
# 随机选择匹配的表情包
|
||||||
import random
|
|
||||||
|
|
||||||
selected_emoji = random.choice(matching_emojis)
|
selected_emoji = random.choice(matching_emojis)
|
||||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||||
|
|
||||||
@@ -185,11 +208,11 @@ def get_count() -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def get_info() -> dict:
|
def get_info():
|
||||||
"""获取表情包系统信息
|
"""获取表情包系统信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: 包含表情包数量、最大数量等信息
|
dict: 包含表情包数量、最大数量、可用数量信息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = get_emoji_manager()
|
||||||
@@ -203,7 +226,7 @@ def get_info() -> dict:
|
|||||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||||
|
|
||||||
|
|
||||||
def get_emotions() -> list:
|
def get_emotions() -> List[str]:
|
||||||
"""获取所有可用的情感标签
|
"""获取所有可用的情感标签
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -223,7 +246,7 @@ def get_emotions() -> list:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def get_descriptions() -> list:
|
def get_descriptions() -> List[str]:
|
||||||
"""获取所有表情包描述
|
"""获取所有表情包描述
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|||||||
@@ -5,11 +5,12 @@
|
|||||||
使用方式:
|
使用方式:
|
||||||
from src.plugin_system.apis import generator_api
|
from src.plugin_system.apis import generator_api
|
||||||
replyer = generator_api.get_replyer(chat_stream)
|
replyer = generator_api.get_replyer(chat_stream)
|
||||||
success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple, Any, Dict, List, Optional
|
from typing import Tuple, Any, Dict, List, Optional
|
||||||
|
from rich.traceback import install
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
from src.chat.replyer.default_generator import DefaultReplyer
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
@@ -17,6 +18,8 @@ from src.chat.utils.utils import process_llm_response
|
|||||||
from src.chat.replyer.replyer_manager import replyer_manager
|
from src.chat.replyer.replyer_manager import replyer_manager
|
||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("generator_api")
|
logger = get_logger("generator_api")
|
||||||
|
|
||||||
|
|
||||||
@@ -44,7 +47,12 @@ def get_replyer(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: chat_stream 和 chat_id 均为空
|
||||||
"""
|
"""
|
||||||
|
if not chat_id and not chat_stream:
|
||||||
|
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||||
return replyer_manager.get_replyer(
|
return replyer_manager.get_replyer(
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from src.config.config import global_config
|
|||||||
|
|
||||||
logger = get_logger("llm_api")
|
logger = get_logger("llm_api")
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# LLM模型API函数
|
# LLM模型API函数
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -31,8 +30,21 @@ def get_available_models() -> Dict[str, Any]:
|
|||||||
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
|
logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
# 自动获取所有属性并转换为字典形式
|
||||||
|
rets = {}
|
||||||
models = global_config.model
|
models = global_config.model
|
||||||
return models
|
attrs = dir(models)
|
||||||
|
for attr in attrs:
|
||||||
|
if not attr.startswith("__"):
|
||||||
|
try:
|
||||||
|
value = getattr(models, attr)
|
||||||
|
if not callable(value): # 排除方法
|
||||||
|
rets[attr] = value
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
|
||||||
|
continue
|
||||||
|
return rets
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|||||||
@@ -114,7 +114,11 @@ async def _send_to_target(
|
|||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
sent_msg = await heart_fc_sender.send_message(
|
sent_msg = await heart_fc_sender.send_message(
|
||||||
bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message, show_log=show_log
|
bot_message,
|
||||||
|
typing=typing,
|
||||||
|
set_reply=(anchor_message is not None),
|
||||||
|
storage_message=storage_message,
|
||||||
|
show_log=show_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
if sent_msg:
|
if sent_msg:
|
||||||
@@ -362,7 +366,9 @@ async def custom_to_stream(
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否发送成功
|
bool: 是否发送成功
|
||||||
"""
|
"""
|
||||||
return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log)
|
return await _send_to_target(
|
||||||
|
message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def text_to_group(
|
async def text_to_group(
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class ReplyAction(BaseAction):
|
|||||||
|
|
||||||
reply_to = self.action_data.get("reply_to", "")
|
reply_to = self.action_data.get("reply_to", "")
|
||||||
sender, target = self._parse_reply_target(reply_to)
|
sender, target = self._parse_reply_target(reply_to)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prepared_reply = self.action_data.get("prepared_reply", "")
|
prepared_reply = self.action_data.get("prepared_reply", "")
|
||||||
if not prepared_reply:
|
if not prepared_reply:
|
||||||
|
|||||||
Reference in New Issue
Block a user