diff --git a/changes.md b/changes.md
index 7ec499b43..4986e4d60 100644
--- a/changes.md
+++ b/changes.md
@@ -20,3 +20,4 @@
 - `chat_api.py`中获取流的参数中可以使用一个特殊的枚举类型来获得所有平台的 ChatStream 了。
 - `config_api.py`中的`get_global_config`和`get_plugin_config`方法现在支持嵌套访问的配置键名。
 - `database_api.py`中的`db_query`方法调整了参数顺序以增强参数限制的同时,保证了typing正确;`db_get`方法增加了`single_result`参数,与`db_query`保持一致。
+4. 现在增加了参数类型检查,完善了对应注释
\ No newline at end of file
diff --git a/scripts/import_openie.py b/scripts/import_openie.py
index 1a36fd240..791c64672 100644
--- a/scripts/import_openie.py
+++ b/scripts/import_openie.py
@@ -9,19 +9,17 @@ import os
 from time import sleep

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-
-from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
 from src.chat.knowledge.embedding_store import EmbeddingManager
-from src.chat.knowledge.llm_client import LLMClient
 from src.chat.knowledge.open_ie import OpenIE
 from src.chat.knowledge.kg_manager import KGManager
 from src.common.logger import get_logger
 from src.chat.knowledge.utils.hash import get_sha256
+from src.manager.local_store_manager import local_storage

 # 添加项目根目录到 sys.path
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
+OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")

 logger = get_logger("OpenIE导入")

@@ -63,7 +61,7 @@ def hash_deduplicate(
 ):
         # 段落hash
         paragraph_hash = get_sha256(raw_paragraph)
-        if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
+        if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
             continue
         new_raw_paragraphs[paragraph_hash] = raw_paragraph
         new_triple_list_data[paragraph_hash] = triple_list
@@ -193,15 +191,8 @@ def main():  # sourcery skip: dict-comprehension
     logger.info("----开始导入openie数据----\n")
-    logger.info("创建LLM客户端")
-    llm_client_list = {}
-    for key in global_config["llm_providers"]:
-        llm_client_list[key] = LLMClient(
-            global_config["llm_providers"][key]["base_url"],
-            global_config["llm_providers"][key]["api_key"],
-        )

     # 初始化Embedding库
-    embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
+    embed_manager = EmbeddingManager()
     logger.info("正在从文件加载Embedding库")
     try:
         embed_manager.load_from_file()
@@ -230,7 +222,7 @@ def main():  # sourcery skip: dict-comprehension

     # 数据比对:Embedding库与KG的段落hash集合
     for pg_hash in kg_manager.stored_paragraph_hashes:
-        key = f"{local_storage['pg_namespace']}-{pg_hash}".replace(f"{local_storage['pg_namespace']}-", f"{local_storage['pg_namespace']}-") if False else f"{PG_NAMESPACE}-{pg_hash}"
+        key = f"{local_storage['pg_namespace']}-{pg_hash}"
         if key not in embed_manager.stored_pg_hashes:
             logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")

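Note: the import script now resolves the paragraph namespace through `local_storage` instead of the removed `PG_NAMESPACE` constant. A minimal sketch of the key scheme it checks against `stored_pg_hashes` (the `"paragraph"` default mirrors what `knowledge_lib.py` seeds into `local_storage` later in this diff; the helper name is illustrative):

```python
from src.chat.knowledge.utils.hash import get_sha256

def pg_key(raw_paragraph: str, namespace: str = "paragraph") -> str:
    # import_openie.py and kg_manager.py both build paragraph keys as
    # "<namespace>-<sha256 of the raw paragraph text>"
    return f"{namespace}-{get_sha256(raw_paragraph)}"

# hash_deduplicate() skips a paragraph only when the namespaced key is
# already in stored_pg_hashes AND the bare hash is in stored_paragraph_hashes.
```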
diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py
index b7e2b5592..7370d98d8 100644
--- a/scripts/info_extraction.py
+++ b/scripts/info_extraction.py
@@ -4,7 +4,6 @@ import signal
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from threading import Lock, Event
 import sys
-import glob
 import datetime

 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -13,11 +12,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

 from rich.progress import Progress  # 替换为 rich 进度条
 from src.common.logger import get_logger
-from src.chat.knowledge.lpmmconfig import global_config
+# from src.chat.knowledge.lpmmconfig import global_config
 from src.chat.knowledge.ie_process import info_extract_from_str
-from src.chat.knowledge.llm_client import LLMClient
 from src.chat.knowledge.open_ie import OpenIE
-from src.chat.knowledge.raw_processing import load_raw_data
 from rich.progress import (
     BarColumn,
     TimeElapsedColumn,
@@ -27,24 +24,17 @@ from rich.progress import (
     SpinnerColumn,
     TextColumn,
 )
+from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
+from src.config.config import global_config
+from src.llm_models.utils_model import LLMRequest

 logger = get_logger("LPMM知识库-信息提取")

 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
-IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
-    ROOT_PATH, "data", "imported_lpmm_data"
-)
-OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
-
-# 创建一个线程安全的锁,用于保护文件操作和共享数据
-file_lock = Lock()
-open_ie_doc_lock = Lock()
-
-# 创建一个事件标志,用于控制程序终止
-shutdown_event = Event()
-
+# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
+OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")

 def ensure_dirs():
     """确保临时目录和输出目录存在"""
@@ -54,12 +44,26 @@ def ensure_dirs():
     if not os.path.exists(OPENIE_OUTPUT_DIR):
         os.makedirs(OPENIE_OUTPUT_DIR)
         logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
-    if not os.path.exists(IMPORTED_DATA_PATH):
-        os.makedirs(IMPORTED_DATA_PATH)
-        logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
+    if not os.path.exists(RAW_DATA_PATH):
+        os.makedirs(RAW_DATA_PATH)
+        logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")

+# 创建一个线程安全的锁,用于保护文件操作和共享数据
+file_lock = Lock()
+open_ie_doc_lock = Lock()

-def process_single_text(pg_hash, raw_data, llm_client_list):
+# 创建一个事件标志,用于控制程序终止
+shutdown_event = Event()
+
+lpmm_entity_extract_llm = LLMRequest(
+    model=global_config.model.lpmm_entity_extract,
+    request_type="lpmm.entity_extract"
+)
+lpmm_rdf_build_llm = LLMRequest(
+    model=global_config.model.lpmm_rdf_build,
+    request_type="lpmm.rdf_build"
+)
+def process_single_text(pg_hash, raw_data):
     """处理单个文本的函数,用于线程池"""
     temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
@@ -77,8 +81,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
         os.remove(temp_file_path)

     entity_list, rdf_triple_list = info_extract_from_str(
-        llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
-        llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
+        lpmm_entity_extract_llm,
+        lpmm_rdf_build_llm,
         raw_data,
     )
     if entity_list is None or rdf_triple_list is None:
@@ -113,7 +117,7 @@ def signal_handler(_signum, _frame):
 def main():  # sourcery skip: comprehension-to-generator, extract-method
     # 设置信号处理器
     signal.signal(signal.SIGINT, signal_handler)
-
+    ensure_dirs()  # 确保目录存在
     # 新增用户确认提示
     print("=== 重要操作确认,请认真阅读以下内容哦 ===")
     print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
@@ -130,50 +134,17 @@ def main():  # sourcery skip: comprehension-to-generator, extract-method
     ensure_dirs()  # 确保目录存在

     logger.info("--------进行信息提取--------\n")
-    logger.info("创建LLM客户端")
-    llm_client_list = {
-        key: LLMClient(
-            global_config["llm_providers"][key]["base_url"],
-            global_config["llm_providers"][key]["api_key"],
-        )
-        for key in global_config["llm_providers"]
-    }

-    # 检查 openie 输出目录
-    if not os.path.exists(OPENIE_OUTPUT_DIR):
-        os.makedirs(OPENIE_OUTPUT_DIR)
-        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
-
-    # 确保 TEMP_DIR 目录存在
-    if not os.path.exists(TEMP_DIR):
-        os.makedirs(TEMP_DIR)
-        logger.info(f"已创建缓存目录: {TEMP_DIR}")
-
-    # 遍历IMPORTED_DATA_PATH下所有json文件
-    imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
-    if not imported_files:
-        logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
-        sys.exit(1)
-
-    all_sha256_list = []
-    all_raw_datas = []
-
-    for imported_file in imported_files:
-        logger.info(f"正在处理文件: {imported_file}")
-        try:
-            sha256_list, raw_datas = load_raw_data(imported_file)
-        except Exception as e:
-            logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
-            continue
-        all_sha256_list.extend(sha256_list)
-        all_raw_datas.extend(raw_datas)
+    # 加载原始数据
+    logger.info("正在加载原始数据")
+    all_sha256_list, all_raw_datas = load_raw_data()

     failed_sha256 = []
     open_ie_doc = []

-    workers = global_config["info_extraction"]["workers"]
+    workers = global_config.lpmm_knowledge.info_extraction_workers
     with ThreadPoolExecutor(max_workers=workers) as executor:
         future_to_hash = {
-            executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
+            executor.submit(process_single_text, pg_hash, raw_data): pg_hash
             for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
         }
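Note: `process_single_text` now hands the two module-level `LLMRequest` wrappers straight to `info_extract_from_str`. A hedged sketch of that call outside the thread pool (the sample paragraph is invented; real input comes from `load_raw_data()` below):

```python
from src.chat.knowledge.ie_process import info_extract_from_str

entity_list, rdf_triple_list = info_extract_from_str(
    lpmm_entity_extract_llm,  # NER model wrapper defined above
    lpmm_rdf_build_llm,       # RDF-triple model wrapper defined above
    "麦麦是一个开源的聊天机器人项目。",  # hypothetical paragraph
)
if entity_list is None or rdf_triple_list is None:
    # both slots are None once the extractor has exhausted its retries
    print("信息提取失败")
```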
diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py
index ee8960f66..42a99133f 100644
--- a/scripts/raw_data_preprocessor.py
+++ b/scripts/raw_data_preprocessor.py
@@ -1,40 +1,16 @@
-import json
 import os
 from pathlib import Path
 import sys  # 新增系统模块导入
-import datetime  # 新增导入
-
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from src.common.logger import get_logger
-from src.chat.knowledge.lpmmconfig import global_config
+from src.chat.knowledge.utils.hash import get_sha256

 logger = get_logger("lpmm")

 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
+# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")

-# 新增:确保 RAW_DATA_PATH 存在
-if not os.path.exists(RAW_DATA_PATH):
-    os.makedirs(RAW_DATA_PATH, exist_ok=True)
-    logger.info(f"已创建目录: {RAW_DATA_PATH}")

-if global_config.get("persistence", {}).get("raw_data_path") is not None:
-    IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
-else:
-    IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
-
-# 添加项目根目录到 sys.path
-
-
-def check_and_create_dirs():
-    """检查并创建必要的目录"""
-    required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
-
-    for dir_path in required_dirs:
-        if not os.path.exists(dir_path):
-            os.makedirs(dir_path)
-            logger.info(f"已创建目录: {dir_path}")
-
-
-def process_text_file(file_path):
+def _process_text_file(file_path):
     """处理单个文本文件,返回段落列表"""
     with open(file_path, "r", encoding="utf-8") as f:
         raw = f.read()
@@ -55,54 +31,45 @@
     return paragraphs


-def main():
-    # 新增用户确认提示
-    print("=== 数据预处理脚本 ===")
-    print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
-    print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
-    print("请确保原始数据已放置在正确的目录中。")
-    confirm = input("确认继续执行?(y/n): ").strip().lower()
-    if confirm != "y":
-        logger.info("操作已取消")
-        sys.exit(1)
-    print("\n" + "=" * 40 + "\n")
-
-    # 检查并创建必要的目录
-    check_and_create_dirs()
-
-    # # 检查输出文件是否存在
-    # if os.path.exists(RAW_DATA_PATH):
-    #     logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
-    #     sys.exit(1)
-
-    # if os.path.exists(RAW_DATA_PATH):
-    #     logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
-    #     sys.exit(1)
-
-    # 获取所有原始文本文件
+def _process_multi_files() -> list:
     raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
     if not raw_files:
         logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
         sys.exit(1)

-    # 处理所有文件
     all_paragraphs = []
     for file in raw_files:
         logger.info(f"正在处理文件: {file.name}")
-        paragraphs = process_text_file(file)
+        paragraphs = _process_text_file(file)
         all_paragraphs.extend(paragraphs)
+    return all_paragraphs

-    # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json
-    now = datetime.datetime.now()
-    filename = now.strftime("%m-%d-%H-%S-imported-data.json")
-    output_path = os.path.join(IMPORTED_DATA_PATH, filename)
-    with open(output_path, "w", encoding="utf-8") as f:
-        json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)

-    logger.info(f"处理完成,结果已保存到: {output_path}")
+def load_raw_data() -> tuple[list[str], list[str]]:
+    """加载原始数据文件

-if __name__ == "__main__":
-    logger.info(f"原始数据路径: {RAW_DATA_PATH}")
-    logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
-    main()
+    读取 RAW_DATA_PATH 下的原始数据,按段落的SHA256在内存中去重
+
+    Returns:
+        - sha256_list: 原始数据的SHA256列表
+        - raw_data: 去重后的原始数据列表(与sha256_list一一对应)
+    """
+    all_paragraphs = _process_multi_files()
+    sha256_list = []
+    sha256_set = set()
+    raw_data = []
+    for item in all_paragraphs:
+        if not isinstance(item, str):
+            logger.warning(f"数据类型错误:{item}")
+            continue
+        pg_hash = get_sha256(item)
+        if pg_hash in sha256_set:
+            logger.warning(f"重复数据:{item}")
+            continue
+        sha256_set.add(pg_hash)
+        sha256_list.append(pg_hash)
+        raw_data.append(item)
+    logger.info(f"共读取到{len(raw_data)}条数据")
+
+    return sha256_list, raw_data
\ No newline at end of file
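Note: `load_raw_data()` replaces the old JSON import round trip — paragraphs are hashed and deduplicated in memory, and the two returned lists stay index-aligned so callers can zip them, exactly as `info_extraction.py` does when submitting to the thread pool. A small sketch of the contract under the implementation above:

```python
sha256_list, raw_data = load_raw_data()
assert len(sha256_list) == len(raw_data)  # parallel, deduplicated lists

for pg_hash, paragraph in zip(sha256_list, raw_data, strict=False):
    # pg_hash doubles as the cache-file stem used by info_extraction.py:
    # temp/<pg_hash>.json holds the finished extraction for this paragraph.
    ...
```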
diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 578ff0172..dd9f12c0d 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -47,7 +47,7 @@ class MaiEmoji:
         self.embedding = []
         self.hash = ""  # 初始为空,在创建实例时会计算
         self.description = ""
-        self.emotion = []
+        self.emotion: List[str] = []
         self.usage_count = 0
         self.last_used_time = time.time()
         self.register_time = time.time()
diff --git a/src/chat/focus_chat/heartFC_chat.py b/src/chat/focus_chat/heartFC_chat.py
index 0978b4b59..dd71da7b1 100644
--- a/src/chat/focus_chat/heartFC_chat.py
+++ b/src/chat/focus_chat/heartFC_chat.py
@@ -243,6 +243,8 @@ class HeartFChatting:
             loop_start_time = time.time()
             await self.relationship_builder.build_relation()

+            available_actions = {}
+
             # 第一步:动作修改
             with Timer("动作修改", cycle_timers):
                 try:
diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py
index 1d887e1fe..b827f4b40 100644
--- a/src/chat/knowledge/embedding_store.py
+++ b/src/chat/knowledge/embedding_store.py
@@ -10,8 +10,8 @@ import pandas as pd

 # import tqdm
 import faiss
-from .llm_client import LLMClient
-from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
+# from .llm_client import LLMClient
+from .lpmmconfig import global_config
 from .utils.hash import get_sha256
 from .global_logger import logger
 from rich.traceback import install
@@ -25,6 +25,9 @@ from rich.progress import (
     SpinnerColumn,
     TextColumn,
 )
+from src.manager.local_store_manager import local_storage
+from src.chat.utils.utils import get_embedding
+
 install(extra_lines=3)

 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
@@ -86,21 +89,20 @@ class EmbeddingStoreItem:


 class EmbeddingStore:
-    def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
+    def __init__(self, namespace: str, dir_path: str):
         self.namespace = namespace
-        self.llm_client = llm_client
         self.dir = dir_path
-        self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
-        self.index_file_path = dir_path + "/" + namespace + ".index"
+        self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
+        self.index_file_path = f"{dir_path}/{namespace}.index"
         self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"

-        self.store = dict()
+        self.store = {}
         self.faiss_index = None
         self.idx2hash = None

     def _get_embedding(self, s: str) -> List[float]:
-        return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
+        return get_embedding(s)

     def get_test_file_path(self):
         return EMBEDDING_TEST_FILE
@@ -293,20 +295,17 @@ class EmbeddingStore:


 class EmbeddingManager:
-    def __init__(self, llm_client: LLMClient):
+    def __init__(self):
         self.paragraphs_embedding_store = EmbeddingStore(
-            llm_client,
-            PG_NAMESPACE,
+            local_storage['pg_namespace'],
             EMBEDDING_DATA_DIR_STR,
         )
         self.entities_embedding_store = EmbeddingStore(
-            llm_client,
-            ENT_NAMESPACE,
+            local_storage['ent_namespace'],
             EMBEDDING_DATA_DIR_STR,
         )
         self.relation_embedding_store = EmbeddingStore(
-            llm_client,
-            REL_NAMESPACE,
+            local_storage['rel_namespace'],
             EMBEDDING_DATA_DIR_STR,
         )
         self.stored_pg_hashes = set()
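Note: each `EmbeddingStore` now derives its on-disk files purely from its namespace, which is why the entity and relation stores must not reuse `pg_namespace` — all three stores would otherwise read and write the same `.parquet`/`.index`/`_i2h.json` triple. A sketch of the resulting layout, with an illustrative data directory (`EMBEDDING_DATA_DIR_STR`'s real value is defined elsewhere in the module):

```python
store = EmbeddingStore("paragraph", "data/embedding")
store.embedding_file_path  # "data/embedding/paragraph.parquet"
store.index_file_path      # "data/embedding/paragraph.index"
store.idx2hash_file_path   # "data/embedding/paragraph_i2h.json"
```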
namespace + ".index" + self.embedding_file_path = f"{dir_path}/{namespace}.parquet" + self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" - self.store = dict() + self.store = {} self.faiss_index = None self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) + return get_embedding(s) def get_test_file_path(self): return EMBEDDING_TEST_FILE @@ -293,20 +295,17 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self, llm_client: LLMClient): + def __init__(self): self.paragraphs_embedding_store = EmbeddingStore( - llm_client, - PG_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( - llm_client, - ENT_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( - llm_client, - REL_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index f68a848d2..bd0e17684 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -4,28 +4,35 @@ from typing import List, Union from .global_logger import logger from . import prompt_template -from .lpmmconfig import global_config, INVALID_ENTITY -from .llm_client import LLMClient -from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json +from .knowledge_lib import INVALID_ENTITY +from src.llm_models.utils_model import LLMRequest +from json_repair import repair_json +def _extract_json_from_text(text: str) -> dict: + """从文本中提取JSON数据的高容错方法""" + try: + fixed_json = repair_json(text) + if isinstance(fixed_json, str): + parsed_json = json.loads(fixed_json) + else: + parsed_json = fixed_json + if isinstance(parsed_json, list) and parsed_json: + parsed_json = parsed_json[0] -def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: + if isinstance(parsed_json, dict): + return parsed_json + + except Exception as e: + logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...") + +def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) - _, request_result = llm_client.send_chat_request( - global_config["entity_extract"]["llm"]["model"], entity_extract_context - ) - - # 去除‘{’前的内容(结果中可能有多个‘{’) - if "[" in request_result: - request_result = request_result[request_result.index("[") :] - - # 去除最后一个‘}’后的内容(结果中可能有多个‘}’) - if "]" in request_result: - request_result = request_result[: request_result.rindex("]") + 1] - - entity_extract_result = json.loads(new_fix_broken_generated_json(request_result)) + response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context) + entity_extract_result = _extract_json_from_text(response) + # 尝试load JSON数据 + json.loads(entity_extract_result) entity_extract_result = [ entity for entity in entity_extract_result @@ -38,23 +45,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]: +def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" - entity_extract_context = 
diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py
index 1ff651b5e..38f883d0e 100644
--- a/src/chat/knowledge/kg_manager.py
+++ b/src/chat/knowledge/kg_manager.py
@@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank

 from .utils.hash import get_sha256
 from .embedding_store import EmbeddingManager, EmbeddingStoreItem
-from .lpmmconfig import (
-    ENT_NAMESPACE,
-    PG_NAMESPACE,
-    RAG_ENT_CNT_NAMESPACE,
-    RAG_GRAPH_NAMESPACE,
-    RAG_PG_HASH_NAMESPACE,
-    global_config,
-)
+from .lpmmconfig import global_config
+from src.manager.local_store_manager import local_storage
 from .global_logger import logger

-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
-KG_DIR = (
-    os.path.join(ROOT_PATH, "data/rag")
-    if global_config["persistence"]["rag_data_dir"] is None
-    else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
-)
-KG_DIR_STR = str(KG_DIR).replace("\\", "/")
+
+def _get_kg_dir():
+    """
+    安全地获取KG数据目录路径
+    """
+    root_path = local_storage['root_path']
+    if root_path is None:
+        # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
+        logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
+
+    # 获取RAG数据目录
+    rag_data_dir = global_config["persistence"]["rag_data_dir"]
+    if rag_data_dir is None:
+        kg_dir = os.path.join(root_path, "data/rag")
+    else:
+        kg_dir = os.path.join(root_path, rag_data_dir)
+
+    return str(kg_dir).replace("\\", "/")
+
+
+# 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage
+def get_kg_dir_str():
+    """获取KG目录字符串"""
+    return _get_kg_dir()


 class KGManager:
@@ -46,15 +59,15 @@ class KGManager:
         # 存储段落的hash值,用于去重
         self.stored_paragraph_hashes = set()
         # 实体出现次数
-        self.ent_appear_cnt = dict()
+        self.ent_appear_cnt = {}
         # KG
         self.graph = di_graph.DiGraph()

-        # 持久化相关
-        self.dir_path = KG_DIR_STR
-        self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
-        self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
-        self.pg_hash_file_path = self.dir_path + 
"/" + RAG_PG_HASH_NAMESPACE + ".json" + # 持久化相关 - 使用延迟初始化的路径 + self.dir_path = get_kg_dir_str() + self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -109,8 +122,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) + hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -128,8 +141,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - pg_hash_key = PG_NAMESPACE + "-" + str(idx) + ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -144,8 +157,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0])) - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -250,7 +263,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(ENT_NAMESPACE): + if node_hash.startswith(local_storage['ent_namespace']): # 新增实体节点 node = embedding_manager.entities_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -259,7 +272,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(PG_NAMESPACE): + elif node_hash.startswith(local_storage['pg_namespace']): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -340,7 +353,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent) + ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -418,7 +431,7 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE) + (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) ] del ppr_res diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 5540d95e2..87a373a5c 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,4 +1,4 
@@ -from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config +from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.mem_active_manager import MemoryActiveManager @@ -6,10 +6,80 @@ from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.global_logger import logger from src.config.config import global_config as bot_global_config -# try: -# import quick_algo -# except ImportError: -# print("quick_algo not found, please install it first") +from src.manager.local_store_manager import local_storage +import os + +INVALID_ENTITY = [ + "", + "你", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "她们", + "它们", +] +PG_NAMESPACE = "paragraph" +ENT_NAMESPACE = "entity" +REL_NAMESPACE = "relation" + +RAG_GRAPH_NAMESPACE = "rag-graph" +RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" +RAG_PG_HASH_NAMESPACE = "rag-pg-hash" + + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +DATA_PATH = os.path.join(ROOT_PATH, "data") + +def _initialize_knowledge_local_storage(): + """ + 初始化知识库相关的本地存储配置 + 使用字典批量设置,避免重复的if判断 + """ + # 定义所有需要初始化的配置项 + default_configs = { + # 路径配置 + 'root_path': ROOT_PATH, + 'data_path': f"{ROOT_PATH}/data", + + # 实体和命名空间配置 + 'lpmm_invalid_entity': INVALID_ENTITY, + 'pg_namespace': PG_NAMESPACE, + 'ent_namespace': ENT_NAMESPACE, + 'rel_namespace': REL_NAMESPACE, + + # RAG相关命名空间配置 + 'rag_graph_namespace': RAG_GRAPH_NAMESPACE, + 'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, + 'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE + } + + # 日志级别映射:重要配置用info,其他用debug + important_configs = {'root_path', 'data_path'} + + # 批量设置配置项 + initialized_count = 0 + for key, default_value in default_configs.items(): + if local_storage[key] is None: + local_storage[key] = default_value + + # 根据重要性选择日志级别 + if key in important_configs: + logger.info(f"设置{key}: {default_value}") + else: + logger.debug(f"设置{key}: {default_value}") + + initialized_count += 1 + + if initialized_count > 0: + logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") + else: + logger.debug("知识库本地存储配置已存在,跳过初始化") + +# 初始化本地存储路径 +_initialize_knowledge_local_storage() # 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: @@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable: ) # 初始化Embedding库 - embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) + embed_manager = EmbeddingManager() logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() @@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable: qa_manager = QAManager( embed_manager, kg_manager, - llm_client_list[global_config["embedding"]["provider"]], - llm_client_list[global_config["qa"]["llm"]["provider"]], - llm_client_list[global_config["qa"]["llm"]["provider"]], ) # 记忆激活(用于记忆库) diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 7bb96d131..90977fb88 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -4,9 +4,8 @@ import glob from typing import Any, Dict, List -from .lpmmconfig import INVALID_ENTITY, global_config - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH +# from src.manager.local_store_manager import local_storage def _filter_invalid_entities(entities: List[str]) -> List[str]: @@ -107,7 +106,7 @@ class 
OpenIE: @staticmethod def load() -> "OpenIE": """从OPENIE_DIR下所有json文件合并加载OpenIE数据""" - openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"]) + openie_dir = os.path.join(DATA_PATH, "openie") if not os.path.exists(openie_dir): raise Exception(f"OpenIE数据目录不存在: {openie_dir}") json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) @@ -122,12 +121,6 @@ class OpenIE: openie_data = OpenIE._from_dict(data_list) return openie_data - @staticmethod - def save(openie_data: "OpenIE"): - """保存OpenIE数据到文件""" - with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f: - f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) - def extract_entity_dict(self): """提取实体列表""" ner_output_dict = dict( diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 01a3e82d3..8940dbb55 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -5,11 +5,13 @@ from .global_logger import logger # from . import prompt_template from .embedding_store import EmbeddingManager -from .llm_client import LLMClient +# from .llm_client import LLMClient from .kg_manager import KGManager -from .lpmmconfig import global_config +# from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k - +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.utils import get_embedding +from src.config.config import global_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -19,26 +21,25 @@ class QAManager: self, embed_manager: EmbeddingManager, kg_manager: KGManager, - llm_client_embedding: LLMClient, - llm_client_filter: LLMClient, - llm_client_qa: LLMClient, + ): self.embed_manager = embed_manager self.kg_manager = kg_manager - self.llm_client_list = { - "embedding": llm_client_embedding, - "message_filter": llm_client_filter, - "qa": llm_client_qa, - } + # TODO: API-Adapter修改标记 + self.qa_model = LLMRequest( + model=global_config.model.lpmm_qa, + request_type="lpmm.qa" + ) def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: """处理查询""" # 生成问题的Embedding part_start_time = time.perf_counter() - question_embedding = self.llm_client_list["embedding"].send_embedding_request( - global_config["embedding"]["model"], question - ) + question_embedding = get_embedding(question) + if question_embedding is None: + logger.error("生成问题Embedding失败") + return None part_end_time = time.perf_counter() logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") @@ -46,14 +47,15 @@ class QAManager: part_start_time = time.perf_counter() relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["relation_search_top_k"], + global_config.lpmm_knowledge.qa_relation_search_top_k, ) if relation_search_res is not None: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]: + if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: # 未找到相关关系 + logger.debug("未找到相关关系,跳过关系检索") relation_search_res = [] part_end_time = time.perf_counter() @@ -71,7 +73,7 @@ class QAManager: part_start_time = time.perf_counter() paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["paragraph_search_top_k"], + 
global_config.lpmm_knowledge.qa_paragraph_search_top_k, ) part_end_time = time.perf_counter() logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 067ae19a2..a881549f5 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -38,7 +38,9 @@ class HeartFCSender: def __init__(self): self.storage = MessageStorage() - async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True): + async def send_message( + self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True + ): """ 处理、发送并存储一条消息。 diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index cbd4c23ef..23f1d6948 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -79,7 +79,9 @@ class ActionPlanner: self.last_obs_time_mark = 0.0 - async def plan(self, mode: ChatMode = ChatMode.FOCUS) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension + async def plan( + self, mode: ChatMode = ChatMode.FOCUS + ) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension """ 规划器 (Planner): 使用LLM根据上下文决定做出什么动作。 """ diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 7340b6e9f..dddd8e1cc 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -479,18 +479,18 @@ class DefaultReplyer: def build_s4u_chat_history_prompts(self, message_list_before_now: list, target_user_id: str) -> tuple[str, str]: """ 构建 s4u 风格的分离对话 prompt - + Args: message_list_before_now: 历史消息列表 target_user_id: 目标用户ID(当前对话对象) - + Returns: tuple: (核心对话prompt, 背景对话prompt) """ core_dialogue_list = [] background_dialogue_list = [] bot_id = str(global_config.bot.qq_account) - + # 过滤消息:分离bot和目标用户的对话 vs 其他用户的对话 for msg_dict in message_list_before_now: try: @@ -503,11 +503,11 @@ class DefaultReplyer: background_dialogue_list.append(msg_dict) except Exception as e: logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") - + # 构建背景对话 prompt background_dialogue_prompt = "" if background_dialogue_list: - latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size*0.6):] + latest_25_msgs = background_dialogue_list[-int(global_config.chat.max_context_size * 0.6) :] background_dialogue_prompt_str = build_readable_messages( latest_25_msgs, replace_bot_name=True, @@ -516,12 +516,12 @@ class DefaultReplyer: show_pic=False, ) background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}" - + # 构建核心对话 prompt core_dialogue_prompt = "" if core_dialogue_list: - core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size*2):] # 限制消息数量 - + core_dialogue_list = core_dialogue_list[-int(global_config.chat.max_context_size * 2) :] # 限制消息数量 + core_dialogue_prompt_str = build_readable_messages( core_dialogue_list, replace_bot_name=True, @@ -532,7 +532,7 @@ class DefaultReplyer: show_actions=True, ) core_dialogue_prompt = core_dialogue_prompt_str - + return core_dialogue_prompt, background_dialogue_prompt async def build_prompt_reply_context( @@ -578,14 +578,13 @@ class DefaultReplyer: action_description = action_info.description action_descriptions += f"- {action_name}: {action_description}\n" action_descriptions += "\n" - + message_list_before_now_long = get_raw_msg_before_timestamp_with_chat( 
chat_id=chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size * 2, ) - - + message_list_before_now = get_raw_msg_before_timestamp_with_chat( chat_id=chat_id, timestamp=time.time(), @@ -712,8 +711,6 @@ class DefaultReplyer: # 根据sender通过person_info_manager反向查找person_id,再获取user_id person_id = person_info_manager.get_person_id_by_person_name(sender) - - # 根据配置选择使用哪种 prompt 构建模式 if global_config.chat.use_s4u_prompt_mode and person_id: # 使用 s4u 对话构建模式:分离当前对话对象和其他对话 @@ -724,16 +721,15 @@ class DefaultReplyer: except Exception as e: logger.warning(f"无法从person_id {person_id} 获取user_id: {e}") target_user_id = "" - - + # 构建分离的对话 prompt core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts( message_list_before_now_long, target_user_id ) - + # 使用 s4u 风格的模板 template_name = "s4u_style_prompt" - + return await global_prompt_manager.format_prompt( template_name, expression_habits_block=expression_habits_block, diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 4433bae44..4462daba7 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -37,7 +37,7 @@ class PersonalityConfig(ConfigBase): personality_side: str """人格侧写""" - + identity: str = "" """身份特征""" @@ -106,7 +106,6 @@ class ChatConfig(ConfigBase): focus_value: float = 1.0 """麦麦的专注思考能力,越低越容易专注,消耗token也越多""" - def get_current_talk_frequency(self, chat_stream_id: Optional[str] = None) -> float: """ 根据当前时间和聊天流获取对应的 talk_frequency @@ -246,6 +245,7 @@ class ChatConfig(ConfigBase): except (ValueError, IndexError): return None + @dataclass class MessageReceiveConfig(ConfigBase): """消息接收配置类""" @@ -274,8 +274,6 @@ class NormalChatConfig(ConfigBase): """@bot 必然回复""" - - @dataclass class ExpressionConfig(ConfigBase): """表达配置类""" @@ -627,3 +625,12 @@ class ModelConfig(ConfigBase): embedding: dict[str, Any] = field(default_factory=lambda: {}) """嵌入模型配置""" + + lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM实体提取模型配置""" + + lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM RDF构建模型配置""" + + lpmm_qa: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM问答模型配置""" diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 3fde2af5e..ac2281d39 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -41,11 +41,11 @@ class Individuality: personality_side: 人格侧面描述 identity: 身份细节描述 """ - bot_nickname=global_config.bot.nickname - personality_core=global_config.personality.personality_core - personality_side=global_config.personality.personality_side - identity=global_config.personality.identity - + bot_nickname = global_config.bot.nickname + personality_core = global_config.personality.personality_core + personality_side = global_config.personality.personality_side + identity = global_config.personality.identity + logger.info("正在初始化个体特征") person_info_manager = get_person_info_manager() self.bot_person_id = person_info_manager.get_person_id("system", "bot_id") @@ -146,11 +146,10 @@ class Individuality: else: logger.error("人设构建失败") - async def get_personality_block(self) -> str: person_info_manager = get_person_info_manager() bot_person_id = person_info_manager.get_person_id("system", "bot_id") - + bot_name = global_config.bot.nickname if global_config.bot.alias_names: bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" @@ -175,9 +174,8 @@ class Individuality: identity = short_impression[1] prompt_personality = 
f"{personality},{identity}" identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - - return identity_block + return identity_block def _get_config_hash( self, bot_nickname: str, personality_core: str, personality_side: str, identity: list @@ -273,7 +271,6 @@ class Individuality: except IOError as e: logger.error(f"保存meta_info文件失败: {e}") - async def _create_personality(self, personality_core: str, personality_side: str) -> str: # sourcery skip: merge-list-append, move-assign """使用LLM创建压缩版本的impression diff --git a/src/individuality/personality.py b/src/individuality/personality.py index da3005ee7..87907df76 100644 --- a/src/individuality/personality.py +++ b/src/individuality/personality.py @@ -40,7 +40,15 @@ class Personality: return cls._instance @classmethod - def initialize(cls, bot_nickname: str, personality_core: str, personality_side: str, identity: List[str] = None, compress_personality: bool = True, compress_identity: bool = True) -> "Personality": + def initialize( + cls, + bot_nickname: str, + personality_core: str, + personality_side: str, + identity: List[str] = None, + compress_personality: bool = True, + compress_identity: bool = True, + ) -> "Personality": """初始化人格特质 Args: diff --git a/src/mais4u/mais4u_chat/loading.py b/src/mais4u/mais4u_chat/loading.py index 752aba0ed..64e2a89ff 100644 --- a/src/mais4u/mais4u_chat/loading.py +++ b/src/mais4u/mais4u_chat/loading.py @@ -1,6 +1,7 @@ from src.plugin_system.apis import send_api + async def send_loading(chat_id: str, content: str): await send_api.custom_to_stream( message_type="loading", @@ -9,7 +10,8 @@ async def send_loading(chat_id: str, content: str): storage_message=False, show_log=True, ) - + + async def send_unloading(chat_id: str): await send_api.custom_to_stream( message_type="loading", @@ -18,4 +20,3 @@ async def send_unloading(chat_id: str): storage_message=False, show_log=True, ) - \ No newline at end of file diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index ab6f36094..041229eb1 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -125,7 +125,7 @@ class ChatMood: ) self.last_change_time = 0 - + # 发送初始情绪状态到ws端 asyncio.create_task(self.send_emotion_update(self.mood_values)) @@ -231,10 +231,10 @@ class ChatMood: if numerical_mood_response: _old_mood_values = self.mood_values.copy() self.mood_values = numerical_mood_response - + # 发送情绪更新到ws端 await self.send_emotion_update(self.mood_values) - + logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}") self.last_change_time = message_time @@ -308,10 +308,10 @@ class ChatMood: if numerical_mood_response: _old_mood_values = self.mood_values.copy() self.mood_values = numerical_mood_response - + # 发送情绪更新到ws端 await self.send_emotion_update(self.mood_values) - + logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}") self.regression_count += 1 @@ -322,9 +322,9 @@ class ChatMood: "joy": mood_values.get("joy", 5), "anger": mood_values.get("anger", 1), "sorrow": mood_values.get("sorrow", 1), - "fear": mood_values.get("fear", 1) + "fear": mood_values.get("fear", 1), } - + await send_api.custom_to_stream( message_type="emotion", content=emotion_data, @@ -332,7 +332,7 @@ class ChatMood: storage_message=False, show_log=True, ) - + logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}") @@ -345,27 +345,27 @@ class MoodRegressionTask(AsyncTask): async def run(self): self.run_count += 1 logger.info(f"[回归任务] 
第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态") - + now = time.time() regression_executed = 0 - + for mood in self.mood_manager.mood_list: chat_info = f"chat {mood.chat_id}" - + if mood.last_change_time == 0: logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归") continue time_since_last_change = now - mood.last_change_time - + # 检查是否有极端情绪需要快速回归 high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8} has_extreme_emotion = len(high_emotions) > 0 - + # 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s should_regress = False regress_reason = "" - + if time_since_last_change > 120: should_regress = True regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)" @@ -373,24 +373,28 @@ class MoodRegressionTask(AsyncTask): should_regress = True high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()]) regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)" - + if should_regress: if mood.regression_count >= 3: logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归") continue - logger.info(f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)") + logger.info( + f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)" + ) await mood.regress_mood() regression_executed += 1 else: if has_extreme_emotion: remaining_time = 5 - time_since_last_change high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()]) - logger.debug(f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒") + logger.debug( + f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒" + ) else: remaining_time = 120 - time_since_last_change logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒") - + if regression_executed > 0: logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归") else: @@ -409,11 +413,11 @@ class MoodManager: return logger.info("启动情绪管理任务...") - + # 启动情绪回归任务 regression_task = MoodRegressionTask(self) await async_task_manager.add_task(regression_task) - + self.task_started = True logger.info("情绪管理任务已启动(情绪回归)") @@ -435,7 +439,7 @@ class MoodManager: # 发送重置后的情绪状态到ws端 asyncio.create_task(mood.send_emotion_update(mood.mood_values)) return - + # 如果没有找到现有的mood,创建新的 new_mood = ChatMood(chat_id) self.mood_list.append(new_mood) diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 09d838bdd..7a2c78042 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -107,7 +107,6 @@ class S4UStreamGenerator: model_name: str, **kwargs, ) -> AsyncGenerator[str, None]: - buffer = "" delimiters = ",。!?,.!?\n\r" # For final trimming punctuation_buffer = "" diff --git a/src/mais4u/mais4u_chat/s4u_watching_manager.py b/src/mais4u/mais4u_chat/s4u_watching_manager.py index 897ef7f70..0ef684340 100644 --- a/src/mais4u/mais4u_chat/s4u_watching_manager.py +++ b/src/mais4u/mais4u_chat/s4u_watching_manager.py @@ -43,23 +43,24 @@ logger = get_logger("watching") class WatchingState(Enum): """视线状态枚举""" + WANDERING = "wandering" # 随意看 - DANMU = "danmu" # 看弹幕 - LENS = "lens" # 看镜头 + DANMU = "danmu" # 看弹幕 + LENS = "lens" # 看镜头 class ChatWatching: def __init__(self, chat_id: str): self.chat_id: str = chat_id self.current_state: WatchingState = WatchingState.LENS # 默认看镜头 - self.last_sent_state: Optional[WatchingState] = None # 上次发送的状态 - self.state_needs_update: bool = True # 是否需要更新状态 - + self.last_sent_state: Optional[WatchingState] = None # 
上次发送的状态 + self.state_needs_update: bool = True # 是否需要更新状态 + # 状态切换相关 - self.is_replying: bool = False # 是否正在生成回复 - self.reply_finished_time: Optional[float] = None # 回复完成时间 - self.danmu_viewing_duration: float = 1.0 # 看弹幕持续时间(秒) - + self.is_replying: bool = False # 是否正在生成回复 + self.reply_finished_time: Optional[float] = None # 回复完成时间 + self.danmu_viewing_duration: float = 1.0 # 看弹幕持续时间(秒) + logger.info(f"[{self.chat_id}] 视线管理器初始化,默认状态: {self.current_state.value}") async def _change_state(self, new_state: WatchingState, reason: str = ""): @@ -69,7 +70,7 @@ class ChatWatching: self.current_state = new_state self.state_needs_update = True logger.info(f"[{self.chat_id}] 视线状态切换: {old_state.value} → {new_state.value} ({reason})") - + # 立即发送视线状态更新 await self._send_watching_update() else: @@ -86,7 +87,7 @@ class ChatWatching: """开始生成回复时调用""" self.is_replying = True self.reply_finished_time = None - + if look_at_lens: await self._change_state(WatchingState.LENS, "开始生成回复-看镜头") else: @@ -96,35 +97,29 @@ class ChatWatching: """生成回复完毕时调用""" self.is_replying = False self.reply_finished_time = time.time() - + # 先看弹幕1秒 await self._change_state(WatchingState.DANMU, "回复完毕-看弹幕") logger.info(f"[{self.chat_id}] 回复完毕,将看弹幕{self.danmu_viewing_duration}秒后转为看镜头") - + # 设置定时器,1秒后自动切换到看镜头 asyncio.create_task(self._auto_switch_to_lens()) async def _auto_switch_to_lens(self): """自动切换到看镜头(延迟执行)""" await asyncio.sleep(self.danmu_viewing_duration) - + # 检查是否仍需要切换(可能状态已经被其他事件改变) - if (self.reply_finished_time is not None and - self.current_state == WatchingState.DANMU and - not self.is_replying): - + if self.reply_finished_time is not None and self.current_state == WatchingState.DANMU and not self.is_replying: await self._change_state(WatchingState.LENS, "看弹幕时间结束") self.reply_finished_time = None # 重置完成时间 async def _send_watching_update(self): """立即发送视线状态更新""" await send_api.custom_to_stream( - message_type="watching", - content=self.current_state.value, - stream_id=self.chat_id, - storage_message=False + message_type="watching", content=self.current_state.value, stream_id=self.chat_id, storage_message=False ) - + logger.info(f"[{self.chat_id}] 发送视线状态更新: {self.current_state.value}") self.last_sent_state = self.current_state self.state_needs_update = False @@ -139,11 +134,10 @@ class ChatWatching: "current_state": self.current_state.value, "is_replying": self.is_replying, "reply_finished_time": self.reply_finished_time, - "state_needs_update": self.state_needs_update + "state_needs_update": self.state_needs_update, } - class WatchingManager: def __init__(self): self.watching_list: list[ChatWatching] = [] @@ -156,7 +150,7 @@ class WatchingManager: return logger.info("启动视线管理系统...") - + self.task_started = True logger.info("视线管理系统已启动(状态变化时立即发送)") @@ -169,10 +163,10 @@ class WatchingManager: new_watching = ChatWatching(chat_id) self.watching_list.append(new_watching) logger.info(f"为chat {chat_id}创建新的视线管理器") - + # 发送初始状态 asyncio.create_task(new_watching._send_watching_update()) - + return new_watching def reset_watching_by_chat_id(self, chat_id: str): @@ -185,27 +179,24 @@ class WatchingManager: watching.is_replying = False watching.reply_finished_time = None logger.info(f"[{chat_id}] 视线状态已重置为默认状态") - + # 发送重置后的状态 asyncio.create_task(watching._send_watching_update()) return - + # 如果没有找到现有的watching,创建新的 new_watching = ChatWatching(chat_id) self.watching_list.append(new_watching) logger.info(f"为chat {chat_id}创建并重置视线管理器") - + # 发送初始状态 asyncio.create_task(new_watching._send_watching_update()) def get_all_watching_info(self) -> dict: 
"""获取所有聊天的视线状态信息(用于调试)""" - return { - watching.chat_id: watching.get_state_info() - for watching in self.watching_list - } + return {watching.chat_id: watching.get_state_info() for watching in self.watching_list} # 全局视线管理器实例 watching_manager = WatchingManager() -"""全局视线管理器""" \ No newline at end of file +"""全局视线管理器""" diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index acd22fd5b..b47785401 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -46,10 +46,10 @@ def init_prompt(): class ChatMood: def __init__(self, chat_id: str): self.chat_id: str = chat_id - + chat_manager = get_chat_manager() self.chat_stream = chat_manager.get_stream(self.chat_id) - + self.log_prefix = f"[{self.chat_stream.group_info.group_name if self.chat_stream.group_info else self.chat_stream.user_info.user_nickname}]" self.mood_state: str = "感觉很平静" @@ -92,7 +92,7 @@ class ChatMood: chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, - limit=int(global_config.chat.max_context_size/3), + limit=int(global_config.chat.max_context_size / 3), limit_mode="last", ) chat_talking_prompt = build_readable_messages( @@ -121,14 +121,12 @@ class ChatMood: mood_state=self.mood_state, ) - - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} response: {response}") logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}") - + logger.info(f"{self.log_prefix} 情绪状态更新为: {response}") self.mood_state = response @@ -170,15 +168,14 @@ class ChatMood: mood_state=self.mood_state, ) - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) - + if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} response: {response}") logger.info(f"{self.log_prefix} reasoning_content: {reasoning_content}") - - logger.info(f"{self.log_prefix} 情绪状态回归为: {response}") + + logger.info(f"{self.log_prefix} 情绪状态回归为: {response}") self.mood_state = response diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index f436c4ab5..35a210faa 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -39,7 +39,12 @@ class ChatManager: Returns: List[ChatStream]: 聊天流列表 + + Raises: + TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 """ + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: for _, stream in get_chat_manager().streams.items(): @@ -60,6 +65,8 @@ class ChatManager: Returns: List[ChatStream]: 群聊聊天流列表 """ + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: for _, stream in get_chat_manager().streams.items(): @@ -79,7 +86,12 @@ class ChatManager: Returns: List[ChatStream]: 私聊聊天流列表 + + Raises: + TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 """ + if not isinstance(platform, (str, SpecialTypes)): + raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: for _, stream in get_chat_manager().streams.items(): @@ -102,7 +114,17 @@ class ChatManager: Returns: Optional[ChatStream]: 聊天流对象,如果未找到返回None + + Raises: + ValueError: 如果 group_id 为空字符串 + TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes """ + if not isinstance(group_id, str): + raise TypeError("group_id 必须是字符串类型") + if not 
isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
+        if not group_id:
+            raise ValueError("group_id 不能为空")
         try:
             for _, stream in get_chat_manager().streams.items():
                 if (
@@ -129,7 +151,17 @@

         Returns:
             Optional[ChatStream]: 聊天流对象,如果未找到返回None
+
+        Raises:
+            ValueError: 如果 user_id 为空字符串
+            TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
         """
+        if not isinstance(user_id, str):
+            raise TypeError("user_id 必须是字符串类型")
+        if not isinstance(platform, (str, SpecialTypes)):
+            raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
+        if not user_id:
+            raise ValueError("user_id 不能为空")
         try:
             for _, stream in get_chat_manager().streams.items():
                 if (
@@ -153,9 +185,15 @@

         Returns:
             str: 聊天类型 ("group", "private", "unknown")
+
+        Raises:
+            TypeError: 如果 chat_stream 不是 ChatStream 类型
+            ValueError: 如果 chat_stream 为空
         """
         if not chat_stream:
-            raise ValueError("chat_stream cannot be None")
+            raise ValueError("chat_stream 不能为 None")
+        if not isinstance(chat_stream, ChatStream):
+            raise TypeError("chat_stream 必须是 ChatStream 类型")

         if hasattr(chat_stream, "group_info"):
             return "group" if chat_stream.group_info else "private"
@@ -170,9 +208,15 @@

         Returns:
             Dict[str, Any]: 聊天流信息字典
+
+        Raises:
+            TypeError: 如果 chat_stream 不是 ChatStream 类型
+            ValueError: 如果 chat_stream 为空
         """
         if not chat_stream:
-            return {}
+            raise ValueError("chat_stream 不能为 None")
+        if not isinstance(chat_stream, ChatStream):
+            raise TypeError("chat_stream 必须是 ChatStream 类型")

         try:
             info: Dict[str, Any] = {
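Note: the chat_api getters now fail fast on bad arguments instead of silently returning empty results. A hedged sketch of the new behavior (argument order and invocation style are assumed from the docstrings; `SpecialTypes` is the enum referenced by the checks, whose members this diff does not show):

```python
from src.plugin_system.apis.chat_api import ChatManager

ChatManager.get_stream_by_group_id(123456)  # TypeError: group_id 必须是字符串类型
ChatManager.get_stream_by_group_id("")      # ValueError: group_id 不能为空
stream = ChatManager.get_stream_by_group_id("123456")  # Optional[ChatStream]
```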
None
+
+    Raises:
+        ValueError: 如果情感标签为空字符串
+        TypeError: 如果情感标签不是字符串类型
     """
+    if not emotion:
+        raise ValueError("情感标签不能为空")
+    if not isinstance(emotion, str):
+        raise TypeError("情感标签必须是字符串类型")
     try:
         logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
@@ -146,8 +171,6 @@
             return None

         # 随机选择匹配的表情包
-        import random
-
         selected_emoji = random.choice(matching_emojis)

         emoji_base64 = image_path_to_base64(selected_emoji.full_path)
@@ -185,11 +208,11 @@
     return 0


 def get_info() -> dict:
     """获取表情包系统信息

     Returns:
-        dict: 包含表情包数量、最大数量等信息
+        dict: 包含表情包数量、最大数量、可用数量信息
     """
     try:
         emoji_manager = get_emoji_manager()
@@ -203,7 +226,7 @@
     return {"current_count": 0, "max_count": 0, "available_emojis": 0}


-def get_emotions() -> list:
+def get_emotions() -> List[str]:
     """获取所有可用的情感标签

     Returns:
@@ -223,7 +246,7 @@
     return []


-def get_descriptions() -> list:
+def get_descriptions() -> List[str]:
     """获取所有表情包描述

     Returns:
diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py
index 6c8cc01da..4763dbd1b 100644
--- a/src/plugin_system/apis/generator_api.py
+++ b/src/plugin_system/apis/generator_api.py
@@ -5,11 +5,12 @@
 使用方式:
     from src.plugin_system.apis import generator_api
     replyer = generator_api.get_replyer(chat_stream)
-    success, reply_set = await generator_api.generate_reply(chat_stream, action_data, reasoning)
+    success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
 """

 import traceback
 from typing import Tuple, Any, Dict, List, Optional
+from rich.traceback import install
 from src.common.logger import get_logger
 from src.chat.replyer.default_generator import DefaultReplyer
 from src.chat.message_receive.chat_stream import ChatStream
@@ -17,6 +18,8 @@
 from src.chat.utils.utils import process_llm_response
 from src.chat.replyer.replyer_manager import replyer_manager
 from src.plugin_system.base.component_types import ActionInfo

+install(extra_lines=3)
+
 logger = get_logger("generator_api")


@@ -44,7 +47,12 @@

     Returns:
         Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
+
+    Raises:
+        ValueError: chat_stream 和 chat_id 均为空
     """
+    if not chat_id and not chat_stream:
+        raise ValueError("chat_stream 和 chat_id 不可均为空")
     try:
         logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
         return replyer_manager.get_replyer(
diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py
index 1bcd1f7d2..4c45a38f0 100644
--- a/src/plugin_system/apis/llm_api.py
+++ b/src/plugin_system/apis/llm_api.py
@@ -14,7 +14,6 @@
 from src.config.config import global_config

 logger = get_logger("llm_api")

-
 # =============================================================================
 # LLM模型API函数
 # =============================================================================
@@ -31,8 +30,21 @@
         logger.error("[LLMAPI] 无法获取模型列表:全局配置中未找到 model 配置")
         return {}

+        # 自动获取所有属性并转换为字典形式
+        rets = {}
         models = global_config.model
-        return models
+        attrs = dir(models)
+        for attr in attrs:
+            if not attr.startswith("__"):
+                try:
+                    value = getattr(models, attr)
+                    if not callable(value):  # 排除方法
+                        rets[attr] = value
+                except Exception as e:
+                    logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
+                    continue
+        return rets
+
     except Exception as e:
         logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
         return {}
diff --git 
a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index 3b4738c24..5e0e3e4be 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -114,7 +114,11 @@ async def _send_to_target( # 发送消息 sent_msg = await heart_fc_sender.send_message( - bot_message, typing=typing, set_reply=(anchor_message is not None), storage_message=storage_message, show_log=show_log + bot_message, + typing=typing, + set_reply=(anchor_message is not None), + storage_message=storage_message, + show_log=show_log, ) if sent_msg: @@ -362,7 +366,9 @@ async def custom_to_stream( Returns: bool: 是否发送成功 """ - return await _send_to_target(message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log) + return await _send_to_target( + message_type, content, stream_id, display_message, typing, reply_to, storage_message, show_log + ) async def text_to_group( diff --git a/src/plugins/built_in/core_actions/plugin.py b/src/plugins/built_in/core_actions/plugin.py index 83b0abfda..edcee0574 100644 --- a/src/plugins/built_in/core_actions/plugin.py +++ b/src/plugins/built_in/core_actions/plugin.py @@ -75,7 +75,7 @@ class ReplyAction(BaseAction): reply_to = self.action_data.get("reply_to", "") sender, target = self._parse_reply_target(reply_to) - + try: prepared_reply = self.action_data.get("prepared_reply", "") if not prepared_reply: