From 62d9cd3cd5056453238aaf4c40cacd8b337bd74a Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 13 May 2025 14:18:15 +0800 Subject: [PATCH 1/3] =?UTF-8?q?fix=EF=BC=9A=E4=BF=AE=E5=A4=8D=E7=BB=9F?= =?UTF-8?q?=E8=AE=A1=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/utils/statistic.py | 125 ++++++++++++++++----------------- 1 file changed, 60 insertions(+), 65 deletions(-) diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py index 4c11ba3d8..6a0b95964 100644 --- a/src/plugins/utils/statistic.py +++ b/src/plugins/utils/statistic.py @@ -109,13 +109,13 @@ def _format_online_time(online_seconds: int) -> str: minutes = (total_oneline_time.seconds // 60) % 60 seconds = total_oneline_time.seconds % 60 if days > 0: - # 如果在线时间超过1天,则格式化为“X天X小时X分钟” + # 如果在线时间超过1天,则格式化为"X天X小时X分钟" total_oneline_time_str = f"{total_oneline_time.days}天{hours}小时{minutes}分钟{seconds}秒" elif hours > 0: - # 如果在线时间超过1小时,则格式化为“X小时X分钟X秒” + # 如果在线时间超过1小时,则格式化为"X小时X分钟X秒" total_oneline_time_str = f"{hours}小时{minutes}分钟{seconds}秒" else: - # 其他情况格式化为“X分钟X秒” + # 其他情况格式化为"X分钟X秒" total_oneline_time_str = f"{minutes}分钟{seconds}秒" return total_oneline_time_str @@ -151,7 +151,7 @@ class StatisticOutputTask(AsyncTask): local_storage["deploy_time"] = now.timestamp() self.stat_period: List[Tuple[str, timedelta, str]] = [ - ("all_time", now - deploy_time, "自部署以来"), # 必须保留“all_time” + ("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time" ("last_7_days", timedelta(days=7), "最近7天"), ("last_24_hours", timedelta(days=1), "最近24小时"), ("last_hour", timedelta(hours=1), "最近1小时"), @@ -511,37 +511,64 @@ class StatisticOutputTask(AsyncTask): """ # format总在线时间 + # 按模型分类统计 + model_rows = "\n".join([ + f"" + f"{model_name}" + f"{count}" + f"{stat_data[IN_TOK_BY_MODEL][model_name]}" + f"{stat_data[OUT_TOK_BY_MODEL][model_name]}" + f"{stat_data[TOTAL_TOK_BY_MODEL][model_name]}" + f"{stat_data[COST_BY_MODEL][model_name]:.4f} ¥" + f"" + for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) + ]) + # 按请求类型分类统计 + type_rows = "\n".join([ + f"" + f"{req_type}" + f"{count}" + f"{stat_data[IN_TOK_BY_TYPE][req_type]}" + f"{stat_data[OUT_TOK_BY_TYPE][req_type]}" + f"{stat_data[TOTAL_TOK_BY_TYPE][req_type]}" + f"{stat_data[COST_BY_TYPE][req_type]:.4f} ¥" + f"" + for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) + ]) + # 按用户分类统计 + user_rows = "\n".join([ + f"" + f"{user_id}" + f"{count}" + f"{stat_data[IN_TOK_BY_USER][user_id]}" + f"{stat_data[OUT_TOK_BY_USER][user_id]}" + f"{stat_data[TOTAL_TOK_BY_USER][user_id]}" + f"{stat_data[COST_BY_USER][user_id]:.4f} ¥" + f"" + for user_id, count in sorted(stat_data[REQ_CNT_BY_USER].items()) + ]) + # 聊天消息统计 + chat_rows = "\n".join([ + f"{self.name_mapping[chat_id][0]}{count}" + for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) + ]) # 生成HTML return f""" -
-

+

+

统计时段: {start_time.strftime("%Y-%m-%d %H:%M:%S")} ~ {now.strftime("%Y-%m-%d %H:%M:%S")}

-

总在线时间: {_format_online_time(stat_data[ONLINE_TIME])}

-

总消息数: {stat_data[TOTAL_MSG_CNT]}

-

总请求数: {stat_data[TOTAL_REQ_CNT]}

-

总花费: {stat_data[TOTAL_COST]:.4f} ¥

+

总在线时间: {_format_online_time(stat_data[ONLINE_TIME])}

+

总消息数: {stat_data[TOTAL_MSG_CNT]}

+

总请求数: {stat_data[TOTAL_REQ_CNT]}

+

总花费: {stat_data[TOTAL_COST]:.4f} ¥

按模型分类统计

- { - "\n".join( - [ - f"" - f"" - f"" - f"" - f"" - f"" - f"" - f"" - for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items()) - ] - ) - } + {model_rows}
模型名称调用次数输入Token输出TokenToken总量累计花费
{model_name}{count}{stat_data[IN_TOK_BY_MODEL][model_name]}{stat_data[OUT_TOK_BY_MODEL][model_name]}{stat_data[TOTAL_TOK_BY_MODEL][model_name]}{stat_data[COST_BY_MODEL][model_name]:.4f} ¥
@@ -551,21 +578,7 @@ class StatisticOutputTask(AsyncTask): 请求类型调用次数输入Token输出TokenToken总量累计花费 - { - "\n".join( - [ - f"" - f"{req_type}" - f"{count}" - f"{stat_data[IN_TOK_BY_TYPE][req_type]}" - f"{stat_data[OUT_TOK_BY_TYPE][req_type]}" - f"{stat_data[TOTAL_TOK_BY_TYPE][req_type]}" - f"{stat_data[COST_BY_TYPE][req_type]:.4f} ¥" - f"" - for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items()) - ] - ) - } + {type_rows} @@ -575,21 +588,7 @@ class StatisticOutputTask(AsyncTask): 用户名称调用次数输入Token输出TokenToken总量累计花费 - { - "\n".join( - [ - f"" - f"{user_id}" - f"{count}" - f"{stat_data[IN_TOK_BY_USER][user_id]}" - f"{stat_data[OUT_TOK_BY_USER][user_id]}" - f"{stat_data[TOTAL_TOK_BY_USER][user_id]}" - f"{stat_data[COST_BY_USER][user_id]:.4f} ¥" - f"" - for user_id, count in sorted(stat_data[REQ_CNT_BY_USER].items()) - ] - ) - } + {user_rows} @@ -599,14 +598,7 @@ class StatisticOutputTask(AsyncTask): 联系人/群组名称消息数量 - { - "\n".join( - [ - f"{self.name_mapping[chat_id][0]}{count}" - for chat_id, count in sorted(stat_data[MSG_CNT_BY_CHAT].items()) - ] - ) - } + {chat_rows}
@@ -622,6 +614,9 @@ class StatisticOutputTask(AsyncTask): _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) ) + joined_tab_list = "\n".join(tab_list) + joined_tab_content = "\n".join(tab_content_list) + html_template = ( """ @@ -734,10 +729,10 @@ class StatisticOutputTask(AsyncTask):

统计截止时间: {now.strftime("%Y-%m-%d %H:%M:%S")}

- {"\n".join(tab_list)} + {joined_tab_list}
- {"\n".join(tab_content_list)} + {joined_tab_content}
""" + """ From 5dc4616442611142e5781e7e8ced076f19e75a23 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 13 May 2025 14:20:20 +0800 Subject: [PATCH 2/3] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a8c972ab4..040a445bf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ llm_tool_benchmark_results.json MaiBot-Napcat-Adapter-main MaiBot-Napcat-Adapter /test +/log_debug /src/test nonebot-maibot-adapter/ *.zip From a7c235c557421a9acf6f3a1287fbbf291c9c385e Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 13 May 2025 22:14:26 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dlpmm=20template=E7=9A=84?= =?UTF-8?q?=E4=B8=80=E4=BA=9B=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 35 ++++++++--------- scripts/info_extraction.py | 61 +++++++++++++----------------- template/lpmm_config_template.toml | 2 +- 3 files changed, 44 insertions(+), 54 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 472667c14..16bf1aa72 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -21,11 +21,7 @@ from src.plugins.knowledge.src.utils.hash import get_sha256 # 添加项目根目录到 sys.path ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -OPENIE_DIR = ( - global_config["persistence"]["openie_data_path"] - if global_config["persistence"]["openie_data_path"] - else os.path.join(ROOT_PATH, "data/openie") -) +OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie") logger = get_module_logger("OpenIE导入") @@ -49,14 +45,14 @@ def hash_deduplicate( new_triple_list_data: 去重后的三元组 """ # 保存去重后的段落 - new_raw_paragraphs = dict() + new_raw_paragraphs = {} # 保存去重后的三元组 - new_triple_list_data = dict() + new_triple_list_data = {} for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())): # 段落hash paragraph_hash = get_sha256(raw_paragraph) - if ((PG_NAMESPACE + "-" + paragraph_hash) in stored_pg_hashes) and (paragraph_hash in stored_paragraph_hashes): + if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: continue new_raw_paragraphs[paragraph_hash] = raw_paragraph new_triple_list_data[paragraph_hash] = triple_list @@ -65,6 +61,7 @@ def hash_deduplicate( def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: + # sourcery skip: extract-method # 从OpenIE数据中提取段落原文与三元组列表 # 索引的段落原文 raw_paragraphs = openie_data.extract_raw_paragraph_dict() @@ -117,7 +114,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k return False # 新增:提示用户是否删除非法文段继续导入 # 将print移到所有logger.error之后,确保不会被冲掉 - logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs))) + logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") user_choice = input().strip().lower() if user_choice != "y": @@ -133,10 +130,10 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k raw_paragraphs = openie_data.extract_raw_paragraph_dict() entity_list_data = openie_data.extract_entity_dict() triple_list_data = openie_data.extract_triple_dict() - # 再次校验 - if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): - logger.error("删除非法文段后,数据仍不一致,程序终止。") - sys.exit(1) + # 再次校验 + if len(raw_paragraphs) != 
+        logger.error("删除非法文段后,数据仍不一致,程序终止。")
+        sys.exit(1)
     # 将索引换为对应段落的hash值
     logger.info("正在进行段落去重与重索引")
     raw_paragraphs, triple_list_data = hash_deduplicate(
@@ -166,7 +163,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     return True


-def main():
+def main():  # sourcery skip: dict-comprehension
     # 新增确认提示
     print("=== 重要操作确认 ===")
     print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
@@ -185,7 +182,7 @@ def main():
     logger.info("----开始导入openie数据----\n")

     logger.info("创建LLM客户端")
-    llm_client_list = dict()
+    llm_client_list = {}
     for key in global_config["llm_providers"]:
         llm_client_list[key] = LLMClient(
             global_config["llm_providers"][key]["base_url"],
             global_config["llm_providers"][key]["api_key"],
         )
@@ -198,7 +195,7 @@ def main():
     try:
         embed_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载Embedding库时发生错误:{}".format(e))
+        logger.error(f"从文件加载Embedding库时发生错误:{e}")
         if "嵌入模型与本地存储不一致" in str(e):
             logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
             logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
@@ -213,7 +210,7 @@ def main():
     try:
         kg_manager.load_from_file()
     except Exception as e:
-        logger.error("从文件加载KG时发生错误:{}".format(e))
+        logger.error(f"从文件加载KG时发生错误:{e}")
         logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("KG加载完成")

@@ -222,7 +219,7 @@ def main():

     # 数据比对:Embedding库与KG的段落hash集合
     for pg_hash in kg_manager.stored_paragraph_hashes:
-        key = PG_NAMESPACE + "-" + pg_hash
+        key = f"{PG_NAMESPACE}-{pg_hash}"
         if key not in embed_manager.stored_pg_hashes:
             logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")

@@ -230,7 +227,7 @@ def main():
     try:
         openie_data = OpenIE.load()
     except Exception as e:
-        logger.error("导入OpenIE数据文件时发生错误:{}".format(e))
+        logger.error(f"导入OpenIE数据文件时发生错误:{e}")
         return False
     if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
         logger.error("处理OpenIE数据时发生错误")
diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py
index 2191d1a95..44ded983a 100644
--- a/scripts/info_extraction.py
+++ b/scripts/info_extraction.py
@@ -33,16 +33,10 @@ logger = get_module_logger("LPMM知识库-信息提取")
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")

-IMPORTED_DATA_PATH = (
-    global_config["persistence"]["raw_data_path"]
-    if global_config["persistence"]["raw_data_path"]
-    else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
-)
-OPENIE_OUTPUT_DIR = (
-    global_config["persistence"]["openie_data_path"]
-    if global_config["persistence"]["openie_data_path"]
-    else os.path.join(ROOT_PATH, "data/openie")
+IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
+    ROOT_PATH, "data/imported_lpmm_data"
 )
+OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data/openie")

 # 创建一个线程安全的锁,用于保护文件操作和共享数据
 file_lock = Lock()
@@ -76,26 +70,25 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
     )
     if entity_list is None or rdf_triple_list is None:
         return None, pg_hash
-    else:
-        doc_item = {
-            "idx": pg_hash,
-            "passage": raw_data,
-            "extracted_entities": entity_list,
-            "extracted_triples": rdf_triple_list,
-        }
-        # 保存临时提取结果
-        with file_lock:
-            try:
-                with open(temp_file_path, "w", encoding="utf-8") as f:
-                    json.dump(doc_item, f, ensure_ascii=False, indent=4)
-            except Exception as e:
-                logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
-                # 如果保存失败,确保不会留下损坏的文件
-                if os.path.exists(temp_file_path):
-                    os.remove(temp_file_path)
-                sys.exit(0)
-            return None, pg_hash
-        return doc_item, None
+    doc_item = {
"idx": pg_hash, + "passage": raw_data, + "extracted_entities": entity_list, + "extracted_triples": rdf_triple_list, + } + # 保存临时提取结果 + with file_lock: + try: + with open(temp_file_path, "w", encoding="utf-8") as f: + json.dump(doc_item, f, ensure_ascii=False, indent=4) + except Exception as e: + logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}") + # 如果保存失败,确保不会留下损坏的文件 + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + sys.exit(0) + return None, pg_hash + return doc_item, None def signal_handler(_signum, _frame): @@ -104,7 +97,7 @@ def signal_handler(_signum, _frame): sys.exit(0) -def main(): +def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) @@ -125,13 +118,13 @@ def main(): logger.info("--------进行信息提取--------\n") logger.info("创建LLM客户端") - llm_client_list = dict() - for key in global_config["llm_providers"]: - llm_client_list[key] = LLMClient( + llm_client_list = { + key: LLMClient( global_config["llm_providers"][key]["base_url"], global_config["llm_providers"][key]["api_key"], ) - + for key in global_config["llm_providers"] + } # 检查 openie 输出目录 if not os.path.exists(OPENIE_OUTPUT_DIR): os.makedirs(OPENIE_OUTPUT_DIR) diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml index 745cbaaf6..5bf24732a 100644 --- a/template/lpmm_config_template.toml +++ b/template/lpmm_config_template.toml @@ -54,7 +54,7 @@ res_top_k = 3 # 最终提供的文段TopK [persistence] # 持久化配置(存储中间数据,防止重复计算) data_root_path = "data" # 数据根目录 -raw_data_path = "data/imported_lpmm_data" # 原始数据路径 +imported_data_path = "data/imported_lpmm_data" # 转换为json的raw文件数据路径 openie_data_path = "data/openie" # OpenIE数据路径 embedding_data_dir = "data/embedding" # 嵌入数据目录 rag_data_dir = "data/rag" # RAG数据目录