From 3bef6f4babc9ee5ac64beeb6aeacc78d12840e57 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Tue, 19 Aug 2025 20:41:00 +0800 Subject: [PATCH] =?UTF-8?q?fix(embedding):=20=E5=BD=BB=E5=BA=95=E8=A7=A3?= =?UTF-8?q?=E5=86=B3=E4=BA=8B=E4=BB=B6=E5=BE=AA=E7=8E=AF=E5=86=B2=E7=AA=81?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E7=9A=84=E5=B5=8C=E5=85=A5=E7=94=9F=E6=88=90?= =?UTF-8?q?=E5=BC=82=E5=B8=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 通过以下改动修复嵌入生成过程中的事件循环相关问题: - 在 EmbeddingStore._get_embedding 中,改为同步创建-使用-销毁的新事件循环模式,彻底避免嵌套事件循环问题 - 调整批量嵌入 _get_embeddings_batch_threaded,确保每个线程使用独立、短生命周期的事件循环 - 新增 force_new 参数,LLM 请求嵌入任务时强制创建新的客户端实例,减少跨循环对象复用 - 在 OpenAI 客户端的 embedding 调用处补充详细日志,方便排查网络连接异常 - get_embedding() 每次都重建 LLMRequest,降低实例在多个事件循环中穿梭的概率 此次改动虽然以同步风格“硬掰”异步接口,但对现有接口零破坏,确保了向量数据库及相关知识检索功能的稳定性。(还有就是把的脚本文件夹移回来了) --- scripts/expression_stats.py | 208 +++ scripts/import_openie.py | 268 ++++ scripts/info_extraction.py | 217 +++ scripts/interest_value_analysis.py | 287 ++++ scripts/log_viewer_optimized.py | 1428 ++++++++++++++++++ scripts/manifest_tool.py | 237 +++ scripts/mongodb_to_sqlite.py | 920 +++++++++++ scripts/raw_data_preprocessor.py | 75 + scripts/run.sh | 556 +++++++ scripts/run_lpmm.sh | 51 + scripts/text_length_analysis.py | 394 +++++ src/chat/knowledge/embedding_store.py | 54 +- src/chat/utils/utils.py | 1 + src/llm_models/model_client/base_client.py | 11 +- src/llm_models/model_client/openai_client.py | 6 + src/llm_models/utils_model.py | 5 +- 16 files changed, 4695 insertions(+), 23 deletions(-) create mode 100644 scripts/expression_stats.py create mode 100644 scripts/import_openie.py create mode 100644 scripts/info_extraction.py create mode 100644 scripts/interest_value_analysis.py create mode 100644 scripts/log_viewer_optimized.py create mode 100644 scripts/manifest_tool.py create mode 100644 scripts/mongodb_to_sqlite.py create mode 100644 scripts/raw_data_preprocessor.py create mode 100644 scripts/run.sh create mode 100644 scripts/run_lpmm.sh create mode 100644 scripts/text_length_analysis.py diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py new file mode 100644 index 000000000..4e761d8d1 --- /dev/null +++ b/scripts/expression_stats.py @@ -0,0 +1,208 @@ +import time +import sys +import os +from typing import Dict, List + +# Add project root to Python path +from src.common.database.database_model import Expression, ChatStreams +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + + + + +def get_chat_name(chat_id: str) -> str: + """Get chat name from chat_id by querying ChatStreams table directly""" + try: + # 直接从数据库查询ChatStreams表 + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream is None: + return f"未知聊天 ({chat_id})" + + # 如果有群组信息,显示群组名称 + if chat_stream.group_name: + return f"{chat_stream.group_name} ({chat_id})" + # 如果是私聊,显示用户昵称 + elif chat_stream.user_nickname: + return f"{chat_stream.user_nickname}的私聊 ({chat_id})" + else: + return f"未知聊天 ({chat_id})" + except Exception: + return f"查询失败 ({chat_id})" + + +def calculate_time_distribution(expressions) -> Dict[str, int]: + """Calculate distribution of last active time in days""" + now = time.time() + distribution = { + '0-1天': 0, + '1-3天': 0, + '3-7天': 0, + '7-14天': 0, + '14-30天': 0, + '30-60天': 0, + '60-90天': 0, + '90+天': 0 + } + for expr in expressions: + diff_days = (now - expr.last_active_time) / (24*3600) + if diff_days < 1: + distribution['0-1天'] += 1 + elif diff_days < 3: + distribution['1-3天'] += 1 + elif diff_days < 7: + distribution['3-7天'] += 1 + elif diff_days < 14: + distribution['7-14天'] += 1 + elif diff_days < 30: + distribution['14-30天'] += 1 + elif diff_days < 60: + distribution['30-60天'] += 1 + elif diff_days < 90: + distribution['60-90天'] += 1 + else: + distribution['90+天'] += 1 + return distribution + + +def calculate_count_distribution(expressions) -> Dict[str, int]: + """Calculate distribution of count values""" + distribution = { + '0-1': 0, + '1-2': 0, + '2-3': 0, + '3-4': 0, + '4-5': 0, + '5-10': 0, + '10+': 0 + } + for expr in expressions: + cnt = expr.count + if cnt < 1: + distribution['0-1'] += 1 + elif cnt < 2: + distribution['1-2'] += 1 + elif cnt < 3: + distribution['2-3'] += 1 + elif cnt < 4: + distribution['3-4'] += 1 + elif cnt < 5: + distribution['4-5'] += 1 + elif cnt < 10: + distribution['5-10'] += 1 + else: + distribution['10+'] += 1 + return distribution + + +def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> List[Expression]: + """Get top N most used expressions for a specific chat_id""" + return (Expression.select() + .where(Expression.chat_id == chat_id) + .order_by(Expression.count.desc()) + .limit(top_n)) + + +def show_overall_statistics(expressions, total: int) -> None: + """Show overall statistics""" + time_dist = calculate_time_distribution(expressions) + count_dist = calculate_count_distribution(expressions) + + print("\n=== 总体统计 ===") + print(f"总表达式数量: {total}") + + print("\n上次激活时间分布:") + for period, count in time_dist.items(): + print(f"{period}: {count} ({count/total*100:.2f}%)") + + print("\ncount分布:") + for range_, count in count_dist.items(): + print(f"{range_}: {count} ({count/total*100:.2f}%)") + + +def show_chat_statistics(chat_id: str, chat_name: str) -> None: + """Show statistics for a specific chat""" + chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id)) + chat_total = len(chat_exprs) + + print(f"\n=== {chat_name} ===") + print(f"表达式数量: {chat_total}") + + if chat_total == 0: + print("该聊天没有表达式数据") + return + + # Time distribution for this chat + time_dist = calculate_time_distribution(chat_exprs) + print("\n上次激活时间分布:") + for period, count in time_dist.items(): + if count > 0: + print(f"{period}: {count} ({count/chat_total*100:.2f}%)") + + # Count distribution for this chat + count_dist = calculate_count_distribution(chat_exprs) + print("\ncount分布:") + for range_, count in count_dist.items(): + if count > 0: + print(f"{range_}: {count} ({count/chat_total*100:.2f}%)") + + # Top expressions + print("\nTop 10使用最多的表达式:") + top_exprs = get_top_expressions_by_chat(chat_id, 10) + for i, expr in enumerate(top_exprs, 1): + print(f"{i}. [{expr.type}] Count: {expr.count}") + print(f" Situation: {expr.situation}") + print(f" Style: {expr.style}") + print() + + +def interactive_menu() -> None: + """Interactive menu for expression statistics""" + # Get all expressions + expressions = list(Expression.select()) + if not expressions: + print("数据库中没有找到表达式") + return + + total = len(expressions) + + # Get unique chat_ids and their names + chat_ids = list(set(expr.chat_id for expr in expressions)) + chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids] + chat_info.sort(key=lambda x: x[1]) # Sort by chat name + + while True: + print("\n" + "="*50) + print("表达式统计分析") + print("="*50) + print("0. 显示总体统计") + + for i, (chat_id, chat_name) in enumerate(chat_info, 1): + chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id) + print(f"{i}. {chat_name} ({chat_count}个表达式)") + + print("q. 退出") + + choice = input("\n请选择要查看的统计 (输入序号): ").strip() + + if choice.lower() == 'q': + print("再见!") + break + + try: + choice_num = int(choice) + if choice_num == 0: + show_overall_statistics(expressions, total) + elif 1 <= choice_num <= len(chat_info): + chat_id, chat_name = chat_info[choice_num - 1] + show_chat_statistics(chat_id, chat_name) + else: + print("无效的选择,请重新输入") + except ValueError: + print("请输入有效的数字") + + input("\n按回车键继续...") + + +if __name__ == "__main__": + interactive_menu() \ No newline at end of file diff --git a/scripts/import_openie.py b/scripts/import_openie.py new file mode 100644 index 000000000..c4367892a --- /dev/null +++ b/scripts/import_openie.py @@ -0,0 +1,268 @@ +# try: +# import src.plugins.knowledge.lib.quick_algo +# except ImportError: +# print("未找到quick_algo库,无法使用quick_algo算法") +# print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace") + +import sys +import os +import asyncio +from time import sleep + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from src.chat.knowledge.embedding_store import EmbeddingManager +from src.chat.knowledge.open_ie import OpenIE +from src.chat.knowledge.kg_manager import KGManager +from src.common.logger import get_logger +from src.chat.knowledge.utils.hash import get_sha256 + + +# 添加项目根目录到 sys.path +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") + +logger = get_logger("OpenIE导入") + +def ensure_openie_dir(): + """确保OpenIE数据目录存在""" + if not os.path.exists(OPENIE_DIR): + os.makedirs(OPENIE_DIR) + logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}") + else: + logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}") + + +def hash_deduplicate( + raw_paragraphs: dict[str, str], + triple_list_data: dict[str, list[list[str]]], + stored_pg_hashes: set, + stored_paragraph_hashes: set, +): + """Hash去重 + + Args: + raw_paragraphs: 索引的段落原文 + triple_list_data: 索引的三元组列表 + stored_pg_hashes: 已存储的段落hash集合 + stored_paragraph_hashes: 已存储的段落hash集合 + + Returns: + new_raw_paragraphs: 去重后的段落 + new_triple_list_data: 去重后的三元组 + """ + # 保存去重后的段落 + new_raw_paragraphs = {} + # 保存去重后的三元组 + new_triple_list_data = {} + + for _, (raw_paragraph, triple_list) in enumerate( + zip(raw_paragraphs.values(), triple_list_data.values(), strict=False) + ): + # 段落hash + paragraph_hash = get_sha256(raw_paragraph) + # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash + paragraph_key = f"paragraph-{paragraph_hash}" + if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: + continue + new_raw_paragraphs[paragraph_hash] = raw_paragraph + new_triple_list_data[paragraph_hash] = triple_list + + return new_raw_paragraphs, new_triple_list_data + + +def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool: + # sourcery skip: extract-method + # 从OpenIE数据中提取段落原文与三元组列表 + # 索引的段落原文 + raw_paragraphs = openie_data.extract_raw_paragraph_dict() + # 索引的实体列表 + entity_list_data = openie_data.extract_entity_dict() + # 索引的三元组列表 + triple_list_data = openie_data.extract_triple_dict() + # print(openie_data.docs) + if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): + logger.error("OpenIE数据存在异常") + logger.error(f"原始段落数量:{len(raw_paragraphs)}") + logger.error(f"实体列表数量:{len(entity_list_data)}") + logger.error(f"三元组列表数量:{len(triple_list_data)}") + logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致") + logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况") + logger.error("或者一段中只有符号的情况") + # 新增:检查docs中每条数据的完整性 + logger.error("系统将于2秒后开始检查数据完整性") + sleep(2) + found_missing = False + missing_idxs = [] + for doc in getattr(openie_data, "docs", []): + idx = doc.get("idx", "<无idx>") + passage = doc.get("passage", "<无passage>") + missing = [] + # 检查字段是否存在且非空 + if "passage" not in doc or not doc.get("passage"): + missing.append("passage") + if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list): + missing.append("名词列表缺失") + elif len(doc.get("extracted_entities", [])) == 0: + missing.append("名词列表为空") + if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list): + missing.append("主谓宾三元组缺失") + elif len(doc.get("extracted_triples", [])) == 0: + missing.append("主谓宾三元组为空") + # 输出所有doc的idx + # print(f"检查: idx={idx}") + if missing: + found_missing = True + missing_idxs.append(idx) + logger.error("\n") + logger.error("数据缺失:") + logger.error(f"对应哈希值:{idx}") + logger.error(f"对应文段内容内容:{passage}") + logger.error(f"非法原因:{', '.join(missing)}") + # 确保提示在所有非法数据输出后再输出 + if not found_missing: + logger.info("所有数据均完整,没有发现缺失字段。") + return False + # 新增:提示用户是否删除非法文段继续导入 + # 将print移到所有logger.error之后,确保不会被冲掉 + logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。") + logger.info("\n是否删除所有非法文段后继续导入?(y/n): ", end="") + user_choice = input().strip().lower() + if user_choice != "y": + logger.info("用户选择不删除非法文段,程序终止。") + sys.exit(1) + # 删除非法文段 + logger.info("正在删除非法文段并继续导入...") + # 过滤掉非法文段 + openie_data.docs = [ + doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs + ] + # 重新提取数据 + raw_paragraphs = openie_data.extract_raw_paragraph_dict() + entity_list_data = openie_data.extract_entity_dict() + triple_list_data = openie_data.extract_triple_dict() + # 再次校验 + if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data): + logger.error("删除非法文段后,数据仍不一致,程序终止。") + sys.exit(1) + # 将索引换为对应段落的hash值 + logger.info("正在进行段落去重与重索引") + raw_paragraphs, triple_list_data = hash_deduplicate( + raw_paragraphs, + triple_list_data, + embed_manager.stored_pg_hashes, + kg_manager.stored_paragraph_hashes, + ) + if len(raw_paragraphs) != 0: + # 获取嵌入并保存 + logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}") + logger.info("开始Embedding") + embed_manager.store_new_data_set(raw_paragraphs, triple_list_data) + # Embedding-Faiss重索引 + logger.info("正在重新构建向量索引") + embed_manager.rebuild_faiss_index() + logger.info("向量索引构建完成") + embed_manager.save_to_file() + logger.info("Embedding完成") + # 构建新段落的RAG + logger.info("开始构建RAG") + kg_manager.build_kg(triple_list_data, embed_manager) + kg_manager.save_to_file() + logger.info("RAG构建完成") + else: + logger.info("无新段落需要处理") + return True + + +async def main_async(): # sourcery skip: dict-comprehension + # 新增确认提示 + print("=== 重要操作确认 ===") + print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") + print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") + print("推荐使用硅基流动的Pro/BAAI/bge-m3") + print("每百万Token费用为0.7元") + print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行") + print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != "y": + logger.info("用户取消操作") + print("操作已取消") + sys.exit(1) + print("\n" + "=" * 40 + "\n") + ensure_openie_dir() # 确保OpenIE目录存在 + logger.info("----开始导入openie数据----\n") + + logger.info("创建LLM客户端") + + # 初始化Embedding库 + embed_manager = EmbeddingManager() + logger.info("正在从文件加载Embedding库") + try: + embed_manager.load_from_file() + except Exception as e: + logger.error(f"从文件加载Embedding库时发生错误:{e}") + if "嵌入模型与本地存储不一致" in str(e): + logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") + # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + sys.exit(1) + if "不存在" in str(e): + logger.error("如果你是第一次导入知识,请忽略此错误") + logger.info("Embedding库加载完成") + # 初始化KG + kg_manager = KGManager() + logger.info("正在从文件加载KG") + try: + kg_manager.load_from_file() + except Exception as e: + logger.error(f"从文件加载KG时发生错误:{e}") + logger.error("如果你是第一次导入知识,请忽略此错误") + logger.info("KG加载完成") + + logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") + logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") + + # 数据比对:Embedding库与KG的段落hash集合 + for pg_hash in kg_manager.stored_paragraph_hashes: + # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash + key = f"paragraph-{pg_hash}" + if key not in embed_manager.stored_pg_hashes: + logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") + + logger.info("正在导入OpenIE数据文件") + try: + openie_data = OpenIE.load() + except Exception as e: + logger.error(f"导入OpenIE数据文件时发生错误:{e}") + return False + if handle_import_openie(openie_data, embed_manager, kg_manager) is False: + logger.error("处理OpenIE数据时发生错误") + return False + return None + + +def main(): + """主函数 - 设置新的事件循环并运行异步主函数""" + # 检查是否有现有的事件循环 + try: + loop = asyncio.get_running_loop() + if loop.is_closed(): + # 如果事件循环已关闭,创建新的 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + # 没有运行的事件循环,创建新的 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + # 在新的事件循环中运行异步主函数 + loop.run_until_complete(main_async()) + finally: + # 确保事件循环被正确关闭 + if not loop.is_closed(): + loop.close() + + +if __name__ == "__main__": + # logger.info(f"111111111111111111111111{ROOT_PATH}") + main() diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py new file mode 100644 index 000000000..47ad55a8b --- /dev/null +++ b/scripts/info_extraction.py @@ -0,0 +1,217 @@ +import json +import os +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from threading import Lock, Event +import sys +import datetime + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +# 添加项目根目录到 sys.path + +from rich.progress import Progress # 替换为 rich 进度条 + +from src.common.logger import get_logger +# from src.chat.knowledge.lpmmconfig import global_config +from src.chat.knowledge.ie_process import info_extract_from_str +from src.chat.knowledge.open_ie import OpenIE +from rich.progress import ( + BarColumn, + TimeElapsedColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, + SpinnerColumn, + TextColumn, +) +from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest + +logger = get_logger("LPMM知识库-信息提取") + + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +TEMP_DIR = os.path.join(ROOT_PATH, "temp") +# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") +OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") + +def ensure_dirs(): + """确保临时目录和输出目录存在""" + if not os.path.exists(TEMP_DIR): + os.makedirs(TEMP_DIR) + logger.info(f"已创建临时目录: {TEMP_DIR}") + if not os.path.exists(OPENIE_OUTPUT_DIR): + os.makedirs(OPENIE_OUTPUT_DIR) + logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") + if not os.path.exists(RAW_DATA_PATH): + os.makedirs(RAW_DATA_PATH) + logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") + +# 创建一个线程安全的锁,用于保护文件操作和共享数据 +file_lock = Lock() +open_ie_doc_lock = Lock() + +# 创建一个事件标志,用于控制程序终止 +shutdown_event = Event() + +lpmm_entity_extract_llm = LLMRequest( + model_set=model_config.model_task_config.lpmm_entity_extract, + request_type="lpmm.entity_extract" +) +lpmm_rdf_build_llm = LLMRequest( + model_set=model_config.model_task_config.lpmm_rdf_build, + request_type="lpmm.rdf_build" +) +def process_single_text(pg_hash, raw_data): + """处理单个文本的函数,用于线程池""" + temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" + + # 使用文件锁检查和读取缓存文件 + with file_lock: + if os.path.exists(temp_file_path): + try: + # 存在对应的提取结果 + logger.info(f"找到缓存的提取结果:{pg_hash}") + with open(temp_file_path, "r", encoding="utf-8") as f: + return json.load(f), None + except json.JSONDecodeError: + # 如果JSON文件损坏,删除它并重新处理 + logger.warning(f"缓存文件损坏,重新处理:{pg_hash}") + os.remove(temp_file_path) + + entity_list, rdf_triple_list = info_extract_from_str( + lpmm_entity_extract_llm, + lpmm_rdf_build_llm, + raw_data, + ) + if entity_list is None or rdf_triple_list is None: + return None, pg_hash + doc_item = { + "idx": pg_hash, + "passage": raw_data, + "extracted_entities": entity_list, + "extracted_triples": rdf_triple_list, + } + # 保存临时提取结果 + with file_lock: + try: + with open(temp_file_path, "w", encoding="utf-8") as f: + json.dump(doc_item, f, ensure_ascii=False, indent=4) + except Exception as e: + logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}") + # 如果保存失败,确保不会留下损坏的文件 + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + sys.exit(0) + return None, pg_hash + return doc_item, None + + +def signal_handler(_signum, _frame): + """处理Ctrl+C信号""" + logger.info("\n接收到中断信号,正在优雅地关闭程序...") + sys.exit(0) + + +def main(): # sourcery skip: comprehension-to-generator, extract-method + # 设置信号处理器 + signal.signal(signal.SIGINT, signal_handler) + ensure_dirs() # 确保目录存在 + # 新增用户确认提示 + print("=== 重要操作确认,请认真阅读以下内容哦 ===") + print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") + print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。") + print("建议使用硅基流动的非Pro模型") + print("或者使用可以用赠金抵扣的Pro模型") + print("请确保账户余额充足,并且在执行前确认无误。") + confirm = input("确认继续执行?(y/n): ").strip().lower() + if confirm != "y": + logger.info("用户取消操作") + print("操作已取消") + sys.exit(1) + print("\n" + "=" * 40 + "\n") + ensure_dirs() # 确保目录存在 + logger.info("--------进行信息提取--------\n") + + # 加载原始数据 + logger.info("正在加载原始数据") + all_sha256_list, all_raw_datas = load_raw_data() + + failed_sha256 = [] + open_ie_doc = [] + + workers = global_config.lpmm_knowledge.info_extraction_workers + with ThreadPoolExecutor(max_workers=workers) as executor: + future_to_hash = { + executor.submit(process_single_text, pg_hash, raw_data): pg_hash + for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False) + } + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + "•", + TimeElapsedColumn(), + "<", + TimeRemainingColumn(), + transient=False, + ) as progress: + task = progress.add_task("正在进行提取:", total=len(future_to_hash)) + try: + for future in as_completed(future_to_hash): + if shutdown_event.is_set(): + for f in future_to_hash: + if not f.done(): + f.cancel() + break + + doc_item, failed_hash = future.result() + if failed_hash: + failed_sha256.append(failed_hash) + logger.error(f"提取失败:{failed_hash}") + elif doc_item: + with open_ie_doc_lock: + open_ie_doc.append(doc_item) + progress.update(task, advance=1) + except KeyboardInterrupt: + logger.info("\n接收到中断信号,正在优雅地关闭程序...") + shutdown_event.set() + for f in future_to_hash: + if not f.done(): + f.cancel() + + # 合并所有文件的提取结果并保存 + if open_ie_doc: + sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]]) + num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc]) + openie_obj = OpenIE( + open_ie_doc, + round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0, + round(sum_phrase_words / num_phrases, 4) if num_phrases else 0, + ) + # 输出文件名格式:MM-DD-HH-ss-openie.json + now = datetime.datetime.now() + filename = now.strftime("%m-%d-%H-%S-openie.json") + output_path = os.path.join(OPENIE_OUTPUT_DIR, filename) + with open(output_path, "w", encoding="utf-8") as f: + json.dump( + openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__, + f, + ensure_ascii=False, + indent=4, + ) + logger.info(f"信息提取结果已保存到: {output_path}") + else: + logger.warning("没有可保存的信息提取结果") + + logger.info("--------信息提取完成--------") + logger.info(f"提取失败的文段SHA256:{failed_sha256}") + + +if __name__ == "__main__": + main() diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py new file mode 100644 index 000000000..fba1f160d --- /dev/null +++ b/scripts/interest_value_analysis.py @@ -0,0 +1,287 @@ +import time +import sys +import os +from typing import Dict, List, Tuple, Optional +from datetime import datetime +# Add project root to Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) +from src.common.database.database_model import Messages, ChatStreams #noqa + + + + +def get_chat_name(chat_id: str) -> str: + """Get chat name from chat_id by querying ChatStreams table directly""" + try: + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream is None: + return f"未知聊天 ({chat_id})" + + if chat_stream.group_name: + return f"{chat_stream.group_name} ({chat_id})" + elif chat_stream.user_nickname: + return f"{chat_stream.user_nickname}的私聊 ({chat_id})" + else: + return f"未知聊天 ({chat_id})" + except Exception: + return f"查询失败 ({chat_id})" + + +def format_timestamp(timestamp: float) -> str: + """Format timestamp to readable date string""" + try: + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return "未知时间" + + +def calculate_interest_value_distribution(messages) -> Dict[str, int]: + """Calculate distribution of interest_value""" + distribution = { + '0.000-0.010': 0, + '0.010-0.050': 0, + '0.050-0.100': 0, + '0.100-0.500': 0, + '0.500-1.000': 0, + '1.000-2.000': 0, + '2.000-5.000': 0, + '5.000-10.000': 0, + '10.000+': 0 + } + + for msg in messages: + if msg.interest_value is None or msg.interest_value == 0.0: + continue + + value = float(msg.interest_value) + if value < 0.010: + distribution['0.000-0.010'] += 1 + elif value < 0.050: + distribution['0.010-0.050'] += 1 + elif value < 0.100: + distribution['0.050-0.100'] += 1 + elif value < 0.500: + distribution['0.100-0.500'] += 1 + elif value < 1.000: + distribution['0.500-1.000'] += 1 + elif value < 2.000: + distribution['1.000-2.000'] += 1 + elif value < 5.000: + distribution['2.000-5.000'] += 1 + elif value < 10.000: + distribution['5.000-10.000'] += 1 + else: + distribution['10.000+'] += 1 + + return distribution + + +def get_interest_value_stats(messages) -> Dict[str, float]: + """Calculate basic statistics for interest_value""" + values = [float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0] + + if not values: + return { + 'count': 0, + 'min': 0, + 'max': 0, + 'avg': 0, + 'median': 0 + } + + values.sort() + count = len(values) + + return { + 'count': count, + 'min': min(values), + 'max': max(values), + 'avg': sum(values) / count, + 'median': values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2 + } + + +def get_available_chats() -> List[Tuple[str, str, int]]: + """Get all available chats with message counts""" + try: + # 获取所有有消息的chat_id + chat_counts = {} + for msg in Messages.select(Messages.chat_id).distinct(): + chat_id = msg.chat_id + count = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.interest_value.is_null(False)) & + (Messages.interest_value != 0.0) + ).count() + if count > 0: + chat_counts[chat_id] = count + + # 获取聊天名称 + result = [] + for chat_id, count in chat_counts.items(): + chat_name = get_chat_name(chat_id) + result.append((chat_id, chat_name, count)) + + # 按消息数量排序 + result.sort(key=lambda x: x[2], reverse=True) + return result + except Exception as e: + print(f"获取聊天列表失败: {e}") + return [] + + +def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: + """Get time range input from user""" + print("\n时间范围选择:") + print("1. 最近1天") + print("2. 最近3天") + print("3. 最近7天") + print("4. 最近30天") + print("5. 自定义时间范围") + print("6. 不限制时间") + + choice = input("请选择时间范围 (1-6): ").strip() + + now = time.time() + + if choice == "1": + return now - 24*3600, now + elif choice == "2": + return now - 3*24*3600, now + elif choice == "3": + return now - 7*24*3600, now + elif choice == "4": + return now - 30*24*3600, now + elif choice == "5": + print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") + start_str = input().strip() + print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") + end_str = input().strip() + + try: + start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() + end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() + return start_time, end_time + except ValueError: + print("时间格式错误,将不限制时间范围") + return None, None + else: + return None, None + + +def analyze_interest_values(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: + """Analyze interest values with optional filters""" + + # 构建查询条件 + query = Messages.select().where( + (Messages.interest_value.is_null(False)) & + (Messages.interest_value != 0.0) + ) + + if chat_id: + query = query.where(Messages.chat_id == chat_id) + + if start_time: + query = query.where(Messages.time >= start_time) + + if end_time: + query = query.where(Messages.time <= end_time) + + messages = list(query) + + if not messages: + print("没有找到符合条件的消息") + return + + # 计算统计信息 + distribution = calculate_interest_value_distribution(messages) + stats = get_interest_value_stats(messages) + + # 显示结果 + print("\n=== Interest Value 分析结果 ===") + if chat_id: + print(f"聊天: {get_chat_name(chat_id)}") + else: + print("聊天: 全部聊天") + + if start_time and end_time: + print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}") + elif start_time: + print(f"时间范围: {format_timestamp(start_time)} 之后") + elif end_time: + print(f"时间范围: {format_timestamp(end_time)} 之前") + else: + print("时间范围: 不限制") + + print("\n基本统计:") + print(f"有效消息数量: {stats['count']} (排除null和0值)") + print(f"最小值: {stats['min']:.3f}") + print(f"最大值: {stats['max']:.3f}") + print(f"平均值: {stats['avg']:.3f}") + print(f"中位数: {stats['median']:.3f}") + + print("\nInterest Value 分布:") + total = stats['count'] + for range_name, count in distribution.items(): + if count > 0: + percentage = count / total * 100 + print(f"{range_name}: {count} ({percentage:.2f}%)") + + +def interactive_menu() -> None: + """Interactive menu for interest value analysis""" + + while True: + print("\n" + "="*50) + print("Interest Value 分析工具") + print("="*50) + print("1. 分析全部聊天") + print("2. 选择特定聊天分析") + print("q. 退出") + + choice = input("\n请选择分析模式 (1-2, q): ").strip() + + if choice.lower() == 'q': + print("再见!") + break + + chat_id = None + + if choice == "2": + # 显示可用的聊天列表 + chats = get_available_chats() + if not chats: + print("没有找到有interest_value数据的聊天") + continue + + print(f"\n可用的聊天 (共{len(chats)}个):") + for i, (_cid, name, count) in enumerate(chats, 1): + print(f"{i}. {name} ({count}条有效消息)") + + try: + chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) + if 1 <= chat_choice <= len(chats): + chat_id = chats[chat_choice - 1][0] + else: + print("无效选择") + continue + except ValueError: + print("请输入有效数字") + continue + + elif choice != "1": + print("无效选择") + continue + + # 获取时间范围 + start_time, end_time = get_time_range_input() + + # 执行分析 + analyze_interest_values(chat_id, start_time, end_time) + + input("\n按回车键继续...") + + +if __name__ == "__main__": + interactive_menu() \ No newline at end of file diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py new file mode 100644 index 000000000..d93f50166 --- /dev/null +++ b/scripts/log_viewer_optimized.py @@ -0,0 +1,1428 @@ +import tkinter as tk +from tkinter import ttk, messagebox, filedialog, colorchooser +import json +from pathlib import Path +import threading +import toml +from datetime import datetime +from collections import defaultdict +import os +import time + + +class LogIndex: + """日志索引,用于快速检索和过滤""" + + def __init__(self): + self.entries = [] # 所有日志条目 + self.module_index = defaultdict(list) # 按模块索引 + self.level_index = defaultdict(list) # 按级别索引 + self.filtered_indices = [] # 当前过滤结果的索引 + self.total_entries = 0 + + def add_entry(self, index, entry): + """添加日志条目到索引""" + if index >= len(self.entries): + self.entries.extend([None] * (index - len(self.entries) + 1)) + + self.entries[index] = entry + self.total_entries = max(self.total_entries, index + 1) + + # 更新各种索引 + logger_name = entry.get("logger_name", "") + level = entry.get("level", "") + + self.module_index[logger_name].append(index) + self.level_index[level].append(index) + + def filter_entries(self, modules=None, level=None, search_text=None): + """根据条件过滤日志条目""" + if not modules and not level and not search_text: + self.filtered_indices = list(range(self.total_entries)) + return self.filtered_indices + + candidate_indices = set(range(self.total_entries)) + + # 模块过滤 + if modules and "全部" not in modules: + module_indices = set() + for module in modules: + module_indices.update(self.module_index.get(module, [])) + candidate_indices &= module_indices + + # 级别过滤 + if level and level != "全部": + level_indices = set(self.level_index.get(level, [])) + candidate_indices &= level_indices + + # 文本搜索过滤 + if search_text: + search_text = search_text.lower() + text_indices = set() + for i in candidate_indices: + if i < len(self.entries) and self.entries[i]: + entry = self.entries[i] + text_content = f"{entry.get('logger_name', '')} {entry.get('event', '')}".lower() + if search_text in text_content: + text_indices.add(i) + candidate_indices &= text_indices + + self.filtered_indices = sorted(list(candidate_indices)) + return self.filtered_indices + + def get_filtered_count(self): + """获取过滤后的条目数量""" + return len(self.filtered_indices) + + def get_entry_at_filtered_position(self, position): + """获取过滤结果中指定位置的条目""" + if 0 <= position < len(self.filtered_indices): + index = self.filtered_indices[position] + return self.entries[index] if index < len(self.entries) else None + return None + + +class LogFormatter: + """日志格式化器""" + + def __init__(self, config, custom_module_colors=None, custom_level_colors=None): + self.config = config + + # 日志级别颜色 + self.level_colors = { + "debug": "#FFA500", + "info": "#0000FF", + "success": "#008000", + "warning": "#FFFF00", + "error": "#FF0000", + "critical": "#800080", + } + + # 模块颜色映射 + self.module_colors = { + "api": "#00FF00", + "emoji": "#00FF00", + "chat": "#0080FF", + "config": "#FFFF00", + "common": "#FF00FF", + "tools": "#00FFFF", + "lpmm": "#00FFFF", + "plugin_system": "#FF0080", + "experimental": "#FFFFFF", + "person_info": "#008000", + "individuality": "#000080", + "manager": "#800080", + "llm_models": "#008080", + "plugins": "#800000", + "plugin_api": "#808000", + "remote": "#8000FF", + } + + # 应用自定义颜色 + if custom_module_colors: + self.module_colors.update(custom_module_colors) + if custom_level_colors: + self.level_colors.update(custom_level_colors) + + # 根据配置决定颜色启用状态 + color_text = self.config.get("color_text", "full") + if color_text == "none": + self.enable_colors = False + self.enable_module_colors = False + self.enable_level_colors = False + elif color_text == "title": + self.enable_colors = True + self.enable_module_colors = True + self.enable_level_colors = False + elif color_text == "full": + self.enable_colors = True + self.enable_module_colors = True + self.enable_level_colors = True + else: + self.enable_colors = True + self.enable_module_colors = True + self.enable_level_colors = False + + def format_log_entry(self, log_entry): + """格式化日志条目,返回格式化后的文本和样式标签""" + timestamp = log_entry.get("timestamp", "") + level = log_entry.get("level", "info") + logger_name = log_entry.get("logger_name", "") + event = log_entry.get("event", "") + + # 格式化时间戳 + formatted_timestamp = self.format_timestamp(timestamp) + + # 构建输出部分 + parts = [] + tags = [] + + # 日志级别样式配置 + log_level_style = self.config.get("log_level_style", "lite") + + # 时间戳 + if formatted_timestamp: + if log_level_style == "lite" and self.enable_level_colors: + parts.append(formatted_timestamp) + tags.append(f"level_{level}") + else: + parts.append(formatted_timestamp) + tags.append("timestamp") + + # 日志级别显示 + if log_level_style == "full": + level_text = f"[{level.upper():>8}]" + parts.append(level_text) + if self.enable_level_colors: + tags.append(f"level_{level}") + else: + tags.append("level") + elif log_level_style == "compact": + level_text = f"[{level.upper()[0]:>8}]" + parts.append(level_text) + if self.enable_level_colors: + tags.append(f"level_{level}") + else: + tags.append("level") + + # 模块名称 + if logger_name: + module_text = f"[{logger_name}]" + parts.append(module_text) + if self.enable_module_colors: + tags.append(f"module_{logger_name}") + else: + tags.append("module") + + # 消息内容 + if isinstance(event, str): + parts.append(event) + elif isinstance(event, dict): + try: + parts.append(json.dumps(event, ensure_ascii=False, indent=None)) + except (TypeError, ValueError): + parts.append(str(event)) + else: + parts.append(str(event)) + tags.append("message") + + # 处理其他字段 + extras = [] + for key, value in log_entry.items(): + if key not in ("timestamp", "level", "logger_name", "event"): + if isinstance(value, (dict, list)): + try: + value_str = json.dumps(value, ensure_ascii=False, indent=None) + except (TypeError, ValueError): + value_str = str(value) + else: + value_str = str(value) + extras.append(f"{key}={value_str}") + + if extras: + parts.append(" ".join(extras)) + tags.append("extras") + + return parts, tags + + def format_timestamp(self, timestamp): + """格式化时间戳""" + if not timestamp: + return "" + + try: + if "T" in timestamp: + dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) + else: + return timestamp + + date_style = self.config.get("date_style", "m-d H:i:s") + format_map = { + "Y": "%Y", + "m": "%m", + "d": "%d", + "H": "%H", + "i": "%M", + "s": "%S", + } + + python_format = date_style + for php_char, python_char in format_map.items(): + python_format = python_format.replace(php_char, python_char) + + return dt.strftime(python_format) + except Exception: + return timestamp + + +class VirtualLogDisplay: + """虚拟滚动日志显示组件""" + + def __init__(self, parent, formatter): + self.parent = parent + self.formatter = formatter + self.line_height = 20 # 每行高度(像素) + self.visible_lines = 30 # 可见行数 + + # 创建主框架 + self.main_frame = ttk.Frame(parent) + + # 创建文本框和滚动条 + self.scrollbar = ttk.Scrollbar(self.main_frame) + self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + self.text_widget = tk.Text( + self.main_frame, + wrap=tk.WORD, + yscrollcommand=self.scrollbar.set, + background="#1e1e1e", + foreground="#ffffff", + insertbackground="#ffffff", + selectbackground="#404040", + font=("Consolas", 10), + ) + self.text_widget.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + self.scrollbar.config(command=self.text_widget.yview) + + # 配置文本标签样式 + self.configure_text_tags() + + # 数据源 + self.log_index = None + self.current_page = 0 + self.page_size = 500 # 每页显示条数 + self.max_display_lines = 2000 # 最大显示行数 + + def pack(self, **kwargs): + """包装pack方法""" + self.main_frame.pack(**kwargs) + + def configure_text_tags(self): + """配置文本标签样式""" + # 基础标签 + self.text_widget.tag_configure("timestamp", foreground="#808080") + self.text_widget.tag_configure("level", foreground="#808080") + self.text_widget.tag_configure("module", foreground="#808080") + self.text_widget.tag_configure("message", foreground="#ffffff") + self.text_widget.tag_configure("extras", foreground="#808080") + + # 日志级别颜色标签 + for level, color in self.formatter.level_colors.items(): + self.text_widget.tag_configure(f"level_{level}", foreground=color) + + # 模块颜色标签 + for module, color in self.formatter.module_colors.items(): + self.text_widget.tag_configure(f"module_{module}", foreground=color) + + def set_log_index(self, log_index): + """设置日志索引数据源""" + self.log_index = log_index + self.current_page = 0 + self.refresh_display() + + def refresh_display(self): + """刷新显示""" + if not self.log_index: + self.text_widget.delete(1.0, tk.END) + return + + # 清空显示 + self.text_widget.delete(1.0, tk.END) + + # 批量加载和显示日志 + total_count = self.log_index.get_filtered_count() + if total_count == 0: + self.text_widget.insert(tk.END, "没有符合条件的日志记录\n") + return + + # 计算显示范围 + start_index = 0 + end_index = min(total_count, self.max_display_lines) + + # 批量处理和显示 + batch_size = 100 + for batch_start in range(start_index, end_index, batch_size): + batch_end = min(batch_start + batch_size, end_index) + self.display_batch(batch_start, batch_end) + + # 让UI有机会响应 + self.parent.update_idletasks() + + # 滚动到底部(如果需要) + self.text_widget.see(tk.END) + + def display_batch(self, start_index, end_index): + """批量显示日志条目""" + for i in range(start_index, end_index): + log_entry = self.log_index.get_entry_at_filtered_position(i) + if log_entry: + self.append_entry(log_entry, scroll=False) + + def append_entry(self, log_entry, scroll=True): + """将单个日志条目附加到文本小部件""" + # 检查在添加新内容之前视图是否已滚动到底部 + should_scroll = scroll and self.text_widget.yview()[1] > 0.99 + + parts, tags = self.formatter.format_log_entry(log_entry) + line_text = " ".join(parts) + "\n" + + # 获取插入前的末尾位置 + start_pos = self.text_widget.index(tk.END + "-1c") + self.text_widget.insert(tk.END, line_text) + + # 为每个部分应用正确的标签 + current_len = 0 + for part, tag_name in zip(parts, tags, strict=False): + start_index = f"{start_pos}+{current_len}c" + end_index = f"{start_pos}+{current_len + len(part)}c" + self.text_widget.tag_add(tag_name, start_index, end_index) + current_len += len(part) + 1 # 计入空格 + + if should_scroll: + self.text_widget.see(tk.END) + + +class AsyncLogLoader: + """异步日志加载器""" + + def __init__(self, callback): + self.callback = callback + self.loading = False + self.should_stop = False + + def load_file_async(self, file_path, progress_callback=None): + """异步加载日志文件""" + if self.loading: + return + + self.loading = True + self.should_stop = False + + def load_worker(): + try: + log_index = LogIndex() + + if not os.path.exists(file_path): + self.callback(log_index, "文件不存在") + return + + file_size = os.path.getsize(file_path) + processed_size = 0 + + with open(file_path, "r", encoding="utf-8") as f: + line_count = 0 + batch_size = 1000 # 批量处理 + + while not self.should_stop: + lines = [] + for _ in range(batch_size): + line = f.readline() + if not line: + break + lines.append(line) + processed_size += len(line.encode("utf-8")) + + if not lines: + break + + # 处理这批数据 + for line in lines: + try: + log_entry = json.loads(line.strip()) + log_index.add_entry(line_count, log_entry) + line_count += 1 + except json.JSONDecodeError: + continue + + # 更新进度 + if progress_callback: + progress = min(100, (processed_size / file_size) * 100) + progress_callback(progress, line_count) + + if not self.should_stop: + self.callback(log_index, None) + + except Exception as e: + self.callback(None, str(e)) + finally: + self.loading = False + + thread = threading.Thread(target=load_worker) + thread.daemon = True + thread.start() + + def stop_loading(self): + """停止加载""" + self.should_stop = True + self.loading = False + + +class LogViewer: + def __init__(self, root): + self.root = root + self.root.title("MaiBot日志查看器 (优化版)") + self.root.geometry("1200x800") + + # 加载配置 + self.load_config() + + # 初始化日志格式化器 + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) + + # 初始化日志文件路径 + self.current_log_file = Path("logs/app.log.jsonl") + self.last_file_size = 0 + self.watching_thread = None + self.is_watching = tk.BooleanVar(value=True) + + # 初始化异步加载器 + self.async_loader = AsyncLogLoader(self.on_file_loaded) + + # 初始化日志索引 + self.log_index = LogIndex() + + # 创建主框架 + self.main_frame = ttk.Frame(root) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # 创建菜单栏 + self.create_menu() + + # 创建控制面板 + self.create_control_panel() + + # 创建虚拟滚动日志显示区域 + self.log_display = VirtualLogDisplay(self.main_frame, self.formatter) + self.log_display.pack(fill=tk.BOTH, expand=True) + + # 模块名映射 + self.module_name_mapping = { + "api": "API接口", + "async_task_manager": "异步任务管理器", + "background_tasks": "后台任务", + "base_tool": "基础工具", + "chat_stream": "聊天流", + "component_registry": "组件注册器", + "config": "配置", + "database_model": "数据库模型", + "emoji": "表情", + "heartflow": "心流", + "local_storage": "本地存储", + "lpmm": "LPMM", + "maibot_statistic": "MaiBot统计", + "main_message": "主消息", + "main": "主程序", + "memory": "内存", + "mood": "情绪", + "plugin_manager": "插件管理器", + "remote": "远程", + "willing": "意愿", + } + + # 加载自定义映射 + self.load_module_mapping() + + # 选中的模块集合 + self.selected_modules = set() + self.modules = set() + + # 绑定事件 + self.level_combo.bind("<>", self.filter_logs) + self.search_var.trace("w", self.filter_logs) + + # 绑定快捷键 + self.root.bind("", lambda e: self.select_log_file()) + self.root.bind("", lambda e: self.refresh_log_file()) + self.root.bind("", lambda e: self.export_logs()) + + # 初始加载文件 + if self.current_log_file.exists(): + self.load_log_file_async() + + def load_config(self): + """加载配置文件""" + # 默认配置 + self.default_config = { + "log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full", "log_level": "INFO"}, + "viewer": { + "theme": "dark", + "font_size": 10, + "max_lines": 1000, + "auto_scroll": True, + "show_milliseconds": False, + "window": {"width": 1200, "height": 800, "remember_position": True}, + }, + } + + # 从bot_config.toml加载日志配置 + config_path = Path("config/bot_config.toml") + self.log_config = self.default_config["log"].copy() + self.viewer_config = self.default_config["viewer"].copy() + + try: + if config_path.exists(): + with open(config_path, "r", encoding="utf-8") as f: + bot_config = toml.load(f) + if "log" in bot_config: + self.log_config.update(bot_config["log"]) + except Exception as e: + print(f"加载bot配置失败: {e}") + + # 从viewer配置文件加载查看器配置 + viewer_config_path = Path("config/log_viewer_config.toml") + self.custom_module_colors = {} + self.custom_level_colors = {} + + try: + if viewer_config_path.exists(): + with open(viewer_config_path, "r", encoding="utf-8") as f: + viewer_config = toml.load(f) + if "viewer" in viewer_config: + self.viewer_config.update(viewer_config["viewer"]) + + # 加载自定义模块颜色 + if "module_colors" in viewer_config["viewer"]: + self.custom_module_colors = viewer_config["viewer"]["module_colors"] + + # 加载自定义级别颜色 + if "level_colors" in viewer_config["viewer"]: + self.custom_level_colors = viewer_config["viewer"]["level_colors"] + + if "log" in viewer_config: + self.log_config.update(viewer_config["log"]) + except Exception as e: + print(f"加载查看器配置失败: {e}") + + # 应用窗口配置 + window_config = self.viewer_config.get("window", {}) + window_width = window_config.get("width", 1200) + window_height = window_config.get("height", 800) + self.root.geometry(f"{window_width}x{window_height}") + + def save_viewer_config(self): + """保存查看器配置""" + # 准备完整的配置数据 + viewer_config_copy = self.viewer_config.copy() + + # 保存自定义颜色(只保存与默认值不同的颜色) + if self.custom_module_colors: + viewer_config_copy["module_colors"] = self.custom_module_colors + if self.custom_level_colors: + viewer_config_copy["level_colors"] = self.custom_level_colors + + config_data = {"log": self.log_config, "viewer": viewer_config_copy} + + config_path = Path("config/log_viewer_config.toml") + config_path.parent.mkdir(exist_ok=True) + + try: + with open(config_path, "w", encoding="utf-8") as f: + toml.dump(config_data, f) + except Exception as e: + print(f"保存查看器配置失败: {e}") + + def create_menu(self): + """创建菜单栏""" + menubar = tk.Menu(self.root) + self.root.config(menu=menubar) + + # 配置菜单 + config_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="配置", menu=config_menu) + config_menu.add_command(label="日志格式设置", command=self.show_format_settings) + config_menu.add_command(label="颜色设置", command=self.show_color_settings) + config_menu.add_command(label="查看器设置", command=self.show_viewer_settings) + config_menu.add_separator() + config_menu.add_command(label="重新加载配置", command=self.reload_config) + + # 文件菜单 + file_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="文件", menu=file_menu) + file_menu.add_command(label="选择日志文件", command=self.select_log_file, accelerator="Ctrl+O") + file_menu.add_command(label="刷新当前文件", command=self.refresh_log_file, accelerator="F5") + file_menu.add_separator() + file_menu.add_command(label="导出当前日志", command=self.export_logs, accelerator="Ctrl+S") + + # 工具菜单 + tools_menu = tk.Menu(menubar, tearoff=0) + menubar.add_cascade(label="工具", menu=tools_menu) + tools_menu.add_command(label="清空日志显示", command=self.clear_log_display) + + def show_format_settings(self): + """显示格式设置窗口""" + format_window = tk.Toplevel(self.root) + format_window.title("日志格式设置") + format_window.geometry("400x300") + + frame = ttk.Frame(format_window) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # 日期格式 + ttk.Label(frame, text="日期格式:").pack(anchor="w", pady=2) + date_style_var = tk.StringVar(value=self.log_config.get("date_style", "m-d H:i:s")) + date_entry = ttk.Entry(frame, textvariable=date_style_var, width=30) + date_entry.pack(anchor="w", pady=2) + ttk.Label(frame, text="格式说明: Y=年份, m=月份, d=日期, H=小时, i=分钟, s=秒", font=("", 8)).pack( + anchor="w", pady=2 + ) + + # 日志级别样式 + ttk.Label(frame, text="日志级别样式:").pack(anchor="w", pady=(10, 2)) + level_style_var = tk.StringVar(value=self.log_config.get("log_level_style", "lite")) + level_frame = ttk.Frame(frame) + level_frame.pack(anchor="w", pady=2) + + ttk.Radiobutton(level_frame, text="简洁(lite)", variable=level_style_var, value="lite").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(level_frame, text="紧凑(compact)", variable=level_style_var, value="compact").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(level_frame, text="完整(full)", variable=level_style_var, value="full").pack( + side="left", padx=(0, 10) + ) + + # 颜色文本设置 + ttk.Label(frame, text="文本颜色设置:").pack(anchor="w", pady=(10, 2)) + color_text_var = tk.StringVar(value=self.log_config.get("color_text", "full")) + color_frame = ttk.Frame(frame) + color_frame.pack(anchor="w", pady=2) + + ttk.Radiobutton(color_frame, text="无颜色(none)", variable=color_text_var, value="none").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(color_frame, text="仅标题(title)", variable=color_text_var, value="title").pack( + side="left", padx=(0, 10) + ) + ttk.Radiobutton(color_frame, text="全部(full)", variable=color_text_var, value="full").pack( + side="left", padx=(0, 10) + ) + + # 按钮 + button_frame = ttk.Frame(frame) + button_frame.pack(fill="x", pady=(20, 0)) + + def apply_format(): + self.log_config["date_style"] = date_style_var.get() + self.log_config["log_level_style"] = level_style_var.get() + self.log_config["color_text"] = color_text_var.get() + + # 重新初始化格式化器 + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + + # 保存配置 + self.save_viewer_config() + + # 重新过滤日志以应用新格式 + self.filter_logs() + + format_window.destroy() + + ttk.Button(button_frame, text="应用", command=apply_format).pack(side="right", padx=(5, 0)) + ttk.Button(button_frame, text="取消", command=format_window.destroy).pack(side="right") + + def show_viewer_settings(self): + """显示查看器设置窗口""" + viewer_window = tk.Toplevel(self.root) + viewer_window.title("查看器设置") + viewer_window.geometry("350x250") + + frame = ttk.Frame(viewer_window) + frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # 主题设置 + ttk.Label(frame, text="主题:").pack(anchor="w", pady=2) + theme_var = tk.StringVar(value=self.viewer_config.get("theme", "dark")) + theme_frame = ttk.Frame(frame) + theme_frame.pack(anchor="w", pady=2) + ttk.Radiobutton(theme_frame, text="深色", variable=theme_var, value="dark").pack(side="left", padx=(0, 10)) + ttk.Radiobutton(theme_frame, text="浅色", variable=theme_var, value="light").pack(side="left") + + # 字体大小 + ttk.Label(frame, text="字体大小:").pack(anchor="w", pady=(10, 2)) + font_size_var = tk.IntVar(value=self.viewer_config.get("font_size", 10)) + font_size_spin = ttk.Spinbox(frame, from_=8, to=20, textvariable=font_size_var, width=10) + font_size_spin.pack(anchor="w", pady=2) + + # 最大行数 + ttk.Label(frame, text="最大显示行数:").pack(anchor="w", pady=(10, 2)) + max_lines_var = tk.IntVar(value=self.viewer_config.get("max_lines", 1000)) + max_lines_spin = ttk.Spinbox(frame, from_=100, to=10000, increment=100, textvariable=max_lines_var, width=10) + max_lines_spin.pack(anchor="w", pady=2) + + # 自动滚动 + auto_scroll_var = tk.BooleanVar(value=self.viewer_config.get("auto_scroll", True)) + ttk.Checkbutton(frame, text="自动滚动到底部", variable=auto_scroll_var).pack(anchor="w", pady=(10, 2)) + + # 按钮 + button_frame = ttk.Frame(frame) + button_frame.pack(fill="x", pady=(20, 0)) + + def apply_viewer_settings(): + self.viewer_config["theme"] = theme_var.get() + self.viewer_config["font_size"] = font_size_var.get() + self.viewer_config["max_lines"] = max_lines_var.get() + self.viewer_config["auto_scroll"] = auto_scroll_var.get() + + # 应用主题 + self.apply_theme() + + # 保存配置 + self.save_viewer_config() + + viewer_window.destroy() + + ttk.Button(button_frame, text="应用", command=apply_viewer_settings).pack(side="right", padx=(5, 0)) + ttk.Button(button_frame, text="取消", command=viewer_window.destroy).pack(side="right") + + def apply_theme(self): + """应用主题设置""" + theme = self.viewer_config.get("theme", "dark") + font_size = self.viewer_config.get("font_size", 10) + + # 更新虚拟显示组件的主题 + if theme == "dark": + bg_color = "#1e1e1e" + fg_color = "#ffffff" + select_bg = "#404040" + else: + bg_color = "#ffffff" + fg_color = "#000000" + select_bg = "#c0c0c0" + + self.log_display.text_widget.config( + background=bg_color, foreground=fg_color, selectbackground=select_bg, font=("Consolas", font_size) + ) + + # 重新配置标签样式 + self.log_display.configure_text_tags() + + def reload_config(self): + """重新加载配置""" + self.load_config() + self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.apply_theme() + self.filter_logs() + + def clear_log_display(self): + """清空日志显示""" + self.log_display.text_widget.delete(1.0, tk.END) + + def export_logs(self): + """导出当前显示的日志""" + filename = filedialog.asksaveasfilename( + defaultextension=".txt", filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")] + ) + if filename: + try: + # 获取当前显示的所有日志条目 + if self.log_index: + filtered_count = self.log_index.get_filtered_count() + log_lines = [] + for i in range(filtered_count): + log_entry = self.log_index.get_entry_at_filtered_position(i) + if log_entry: + parts, tags = self.formatter.format_log_entry(log_entry) + line_text = " ".join(parts) + log_lines.append(line_text) + + with open(filename, "w", encoding="utf-8") as f: + f.write("\n".join(log_lines)) + messagebox.showinfo("导出成功", f"日志已导出到: {filename}") + else: + messagebox.showwarning("导出失败", "没有日志可导出") + except Exception as e: + messagebox.showerror("导出失败", f"导出日志时出错: {e}") + + def load_module_mapping(self): + """加载自定义模块映射""" + mapping_file = Path("config/module_mapping.json") + if mapping_file.exists(): + try: + with open(mapping_file, "r", encoding="utf-8") as f: + custom_mapping = json.load(f) + self.module_name_mapping.update(custom_mapping) + except Exception as e: + print(f"加载模块映射失败: {e}") + + def save_module_mapping(self): + """保存自定义模块映射""" + mapping_file = Path("config/module_mapping.json") + mapping_file.parent.mkdir(exist_ok=True) + try: + with open(mapping_file, "w", encoding="utf-8") as f: + json.dump(self.module_name_mapping, f, ensure_ascii=False, indent=2) + except Exception as e: + print(f"保存模块映射失败: {e}") + + def show_color_settings(self): + """显示颜色设置窗口""" + color_window = tk.Toplevel(self.root) + color_window.title("颜色设置") + color_window.geometry("300x400") + + # 创建滚动框架 + frame = ttk.Frame(color_window) + frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # 创建滚动条 + scrollbar = ttk.Scrollbar(frame) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 创建颜色设置列表 + canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) + canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.config(command=canvas.yview) + + # 创建内部框架 + inner_frame = ttk.Frame(canvas) + canvas.create_window((0, 0), window=inner_frame, anchor="nw") + + # 添加日志级别颜色设置 + ttk.Label(inner_frame, text="日志级别颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) + for level in ["info", "warning", "error"]: + frame = ttk.Frame(inner_frame) + frame.pack(fill=tk.X, padx=5, pady=2) + ttk.Label(frame, text=level).pack(side=tk.LEFT) + color_btn = ttk.Button( + frame, text="选择颜色", command=lambda level_name=level: self.choose_color(level_name) + ) + color_btn.pack(side=tk.RIGHT) + # 显示当前颜色 + color_label = ttk.Label(frame, text="■", foreground=self.formatter.level_colors[level]) + color_label.pack(side=tk.RIGHT, padx=5) + + # 添加模块颜色设置 + ttk.Label(inner_frame, text="\n模块颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) + for module in sorted(self.modules): + frame = ttk.Frame(inner_frame) + frame.pack(fill=tk.X, padx=5, pady=2) + ttk.Label(frame, text=module).pack(side=tk.LEFT) + color_btn = ttk.Button(frame, text="选择颜色", command=lambda m=module: self.choose_module_color(m)) + color_btn.pack(side=tk.RIGHT) + # 显示当前颜色 + color = self.formatter.module_colors.get(module, "black") + color_label = ttk.Label(frame, text="■", foreground=color) + color_label.pack(side=tk.RIGHT, padx=5) + + # 更新画布滚动区域 + inner_frame.update_idletasks() + canvas.config(scrollregion=canvas.bbox("all")) + + # 添加确定按钮 + ttk.Button(color_window, text="确定", command=color_window.destroy).pack(pady=5) + + def choose_color(self, level): + """选择日志级别颜色""" + color = colorchooser.askcolor(color=self.formatter.level_colors[level])[1] + if color: + self.formatter.level_colors[level] = color + self.custom_level_colors[level] = color # 保存到自定义颜色 + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.save_viewer_config() # 自动保存配置 + self.filter_logs() + + def choose_module_color(self, module): + """选择模块颜色""" + color = colorchooser.askcolor(color=self.formatter.module_colors.get(module, "black"))[1] + if color: + self.formatter.module_colors[module] = color + self.custom_module_colors[module] = color # 保存到自定义颜色 + self.log_display.formatter = self.formatter + self.log_display.configure_text_tags() + self.save_viewer_config() # 自动保存配置 + self.filter_logs() + + def create_control_panel(self): + """创建控制面板""" + # 控制面板 + self.control_frame = ttk.Frame(self.main_frame) + self.control_frame.pack(fill=tk.X, pady=(0, 5)) + + # 文件选择框架 + self.file_frame = ttk.LabelFrame(self.control_frame, text="日志文件") + self.file_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=(0, 5)) + + # 当前文件显示 + self.current_file_var = tk.StringVar(value=str(self.current_log_file)) + self.file_label = ttk.Label(self.file_frame, textvariable=self.current_file_var, foreground="blue") + self.file_label.pack(side=tk.LEFT, padx=5, pady=2) + + # 进度条 + self.progress_var = tk.DoubleVar() + self.progress_bar = ttk.Progressbar(self.file_frame, variable=self.progress_var, length=200) + self.progress_bar.pack(side=tk.LEFT, padx=5, pady=2) + self.progress_bar.pack_forget() + + # 状态标签 + self.status_var = tk.StringVar(value="就绪") + self.status_label = ttk.Label(self.file_frame, textvariable=self.status_var) + self.status_label.pack(side=tk.LEFT, padx=5, pady=2) + + # 按钮区域 + button_frame = ttk.Frame(self.file_frame) + button_frame.pack(side=tk.RIGHT, padx=5, pady=2) + + ttk.Button(button_frame, text="选择文件", command=self.select_log_file).pack(side=tk.LEFT, padx=2) + ttk.Button(button_frame, text="刷新", command=self.refresh_log_file).pack(side=tk.LEFT, padx=2) + ttk.Checkbutton(button_frame, text="实时更新", variable=self.is_watching, command=self.toggle_watching).pack( + side=tk.LEFT, padx=2 + ) + + # 模块选择框架 + self.module_frame = ttk.LabelFrame(self.control_frame, text="模块") + self.module_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5) + + # 创建模块选择滚动区域 + self.module_canvas = tk.Canvas(self.module_frame, height=80) + self.module_canvas.pack(side=tk.LEFT, fill=tk.X, expand=True) + + # 创建模块选择内部框架 + self.module_inner_frame = ttk.Frame(self.module_canvas) + self.module_canvas.create_window((0, 0), window=self.module_inner_frame, anchor="nw") + + # 创建右侧控制区域(级别和搜索) + self.right_control_frame = ttk.Frame(self.control_frame) + self.right_control_frame.pack(side=tk.RIGHT, padx=5) + + # 映射编辑按钮 + mapping_btn = ttk.Button(self.right_control_frame, text="模块映射", command=self.edit_module_mapping) + mapping_btn.pack(side=tk.TOP, fill=tk.X, pady=1) + + # 日志级别选择 + level_frame = ttk.Frame(self.right_control_frame) + level_frame.pack(side=tk.TOP, fill=tk.X, pady=1) + ttk.Label(level_frame, text="级别:").pack(side=tk.LEFT, padx=2) + self.level_var = tk.StringVar(value="全部") + self.level_combo = ttk.Combobox(level_frame, textvariable=self.level_var, width=8) + self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"] + self.level_combo.pack(side=tk.LEFT, padx=2) + + # 搜索框 + search_frame = ttk.Frame(self.right_control_frame) + search_frame.pack(side=tk.TOP, fill=tk.X, pady=1) + ttk.Label(search_frame, text="搜索:").pack(side=tk.LEFT, padx=2) + self.search_var = tk.StringVar() + self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=15) + self.search_entry.pack(side=tk.LEFT, padx=2) + + def on_file_loaded(self, log_index, error): + """文件加载完成回调""" + self.progress_bar.pack_forget() + + if error: + self.status_var.set(f"加载失败: {error}") + messagebox.showerror("错误", f"加载日志文件失败: {error}") + return + + self.log_index = log_index + try: + self.last_file_size = os.path.getsize(self.current_log_file) + except OSError: + self.last_file_size = 0 + self.status_var.set(f"已加载 {log_index.total_entries} 条日志") + + # 更新模块列表 + self.modules = set(log_index.module_index.keys()) + self.update_module_list() + + # 应用过滤并显示 + self.filter_logs() + + # 如果开启了实时更新,则开始监视 + if self.is_watching.get(): + self.start_watching() + + def on_loading_progress(self, progress, line_count): + """加载进度回调""" + self.root.after(0, lambda: self.update_progress(progress, line_count)) + + def update_progress(self, progress, line_count): + """更新进度显示""" + self.progress_var.set(progress) + self.status_var.set(f"正在加载... {line_count} 条 ({progress:.1f}%)") + + def load_log_file_async(self): + """异步加载日志文件""" + self.stop_watching() # 停止任何正在运行的监视器 + + if not self.current_log_file.exists(): + self.status_var.set("文件不存在") + return + + # 显示进度条 + self.progress_bar.pack(side=tk.LEFT, padx=5, pady=2, before=self.status_label) + self.progress_var.set(0) + self.status_var.set("正在加载...") + + # 清空当前数据 + self.log_index = LogIndex() + self.selected_modules.clear() + + # 开始异步加载 + self.async_loader.load_file_async(str(self.current_log_file), self.on_loading_progress) + + def filter_logs(self, *args): + """过滤日志""" + if not self.log_index: + return + + # 获取过滤条件 + selected_modules = self.selected_modules if self.selected_modules else None + level = self.level_var.get() if self.level_var.get() != "全部" else None + search_text = self.search_var.get().strip() if self.search_var.get().strip() else None + + # 应用过滤 + self.log_index.filter_entries(selected_modules, level, search_text) + + # 更新显示 + self.log_display.set_log_index(self.log_index) + + # 更新状态 + filtered_count = self.log_index.get_filtered_count() + total_count = self.log_index.total_entries + if filtered_count == total_count: + self.status_var.set(f"显示 {total_count} 条日志") + else: + self.status_var.set(f"显示 {filtered_count}/{total_count} 条日志") + + def select_log_file(self): + """选择日志文件""" + filename = filedialog.askopenfilename( + title="选择日志文件", + filetypes=[("JSONL日志文件", "*.jsonl"), ("所有文件", "*.*")], + initialdir="logs" if Path("logs").exists() else ".", + ) + if filename: + new_file = Path(filename) + if new_file != self.current_log_file: + self.current_log_file = new_file + self.current_file_var.set(str(self.current_log_file)) + self.load_log_file_async() + + def refresh_log_file(self): + """刷新日志文件""" + self.load_log_file_async() + + def toggle_watching(self): + """切换实时更新状态""" + if self.is_watching.get(): + self.start_watching() + else: + self.stop_watching() + + def start_watching(self): + """开始监视文件变化""" + if self.watching_thread and self.watching_thread.is_alive(): + return # 已经在监视 + + if not self.current_log_file.exists(): + self.is_watching.set(False) + messagebox.showwarning("警告", "日志文件不存在,无法开启实时更新。") + return + + self.watching_thread = threading.Thread(target=self.watch_file_loop, daemon=True) + self.watching_thread.start() + + def stop_watching(self): + """停止监视文件变化""" + self.is_watching.set(False) + # 线程通过检查 is_watching 变量来停止,这里不需要强制干预 + self.watching_thread = None + + def watch_file_loop(self): + """监视文件循环""" + while self.is_watching.get(): + try: + if not self.current_log_file.exists(): + self.root.after( + 0, + lambda: messagebox.showwarning("警告", "日志文件丢失,已停止实时更新。"), + ) + self.root.after(0, self.is_watching.set, False) + break + + current_size = os.path.getsize(self.current_log_file) + if current_size > self.last_file_size: + new_entries = self.read_new_logs(self.last_file_size) + self.last_file_size = current_size + if new_entries: + self.root.after(0, self.append_new_logs, new_entries) + elif current_size < self.last_file_size: + # 文件被截断或替换 + self.last_file_size = 0 + self.root.after(0, self.refresh_log_file) + break # 刷新会重新启动监视(如果需要),所以结束当前循环 + + except Exception as e: + print(f"监视日志文件时出错: {e}") + self.root.after(0, self.is_watching.set, False) + break + + time.sleep(1) + + self.watching_thread = None + + def read_new_logs(self, from_position): + """读取新的日志条目并返回它们""" + new_entries = [] + new_modules = set() # 收集新发现的模块 + with open(self.current_log_file, "r", encoding="utf-8") as f: + f.seek(from_position) + line_count = self.log_index.total_entries + for line in f: + if line.strip(): + try: + log_entry = json.loads(line) + self.log_index.add_entry(line_count, log_entry) + new_entries.append(log_entry) + + logger_name = log_entry.get("logger_name", "") + if logger_name and logger_name not in self.modules: + new_modules.add(logger_name) + + line_count += 1 + except json.JSONDecodeError: + continue + + # 如果发现了新模块,在主线程中更新模块集合 + if new_modules: + def update_modules(): + self.modules.update(new_modules) + self.update_module_list() + + self.root.after(0, update_modules) + + return new_entries + + def append_new_logs(self, new_entries): + """将新日志附加到显示中""" + # 检查是否应附加或执行完全刷新(例如,如果过滤器处于活动状态) + selected_modules = ( + self.selected_modules if (self.selected_modules and "全部" not in self.selected_modules) else None + ) + level = self.level_var.get() if self.level_var.get() != "全部" else None + search_text = self.search_var.get().strip() if self.search_var.get().strip() else None + + is_filtered = selected_modules or level or search_text + + if is_filtered: + # 如果过滤器处于活动状态,我们必须执行完全刷新以应用它们 + self.filter_logs() + return + + # 如果没有过滤器,只需附加新日志 + for entry in new_entries: + self.log_display.append_entry(entry) + + # 更新状态 + total_count = self.log_index.total_entries + self.status_var.set(f"显示 {total_count} 条日志") + + def update_module_list(self): + """更新模块列表""" + # 清空现有选项 + for widget in self.module_inner_frame.winfo_children(): + widget.destroy() + + # 计算总模块数(包括"全部") + total_modules = len(self.modules) + 1 + max_cols = min(4, max(2, total_modules)) # 减少最大列数,避免超出边界 + + # 配置网格列权重,让每列平均分配空间 + for i in range(max_cols): + self.module_inner_frame.grid_columnconfigure(i, weight=1, uniform="module_col") + + # 创建一个多行布局 + current_row = 0 + current_col = 0 + + # 添加"全部"选项 + all_frame = ttk.Frame(self.module_inner_frame) + all_frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") + + all_var = tk.BooleanVar(value="全部" in self.selected_modules) + all_check = ttk.Checkbutton( + all_frame, text="全部", variable=all_var, command=lambda: self.toggle_module("全部", all_var) + ) + all_check.pack(side=tk.LEFT) + + # 使用颜色标签替代按钮 + all_color = self.formatter.module_colors.get("全部", "black") + all_color_label = ttk.Label(all_frame, text="■", foreground=all_color, width=2, cursor="hand2") + all_color_label.pack(side=tk.LEFT, padx=2) + all_color_label.bind("", lambda e: self.choose_module_color("全部")) + + current_col += 1 + + # 添加其他模块选项 + for module in sorted(self.modules): + if current_col >= max_cols: + current_row += 1 + current_col = 0 + + frame = ttk.Frame(self.module_inner_frame) + frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") + + var = tk.BooleanVar(value=module in self.selected_modules) + + # 使用中文映射名称显示 + display_name = self.get_display_name(module) + if len(display_name) > 12: + display_name = display_name[:10] + "..." + + check = ttk.Checkbutton( + frame, text=display_name, variable=var, command=lambda m=module, v=var: self.toggle_module(m, v) + ) + check.pack(side=tk.LEFT) + + # 添加工具提示显示完整名称和英文名 + full_tooltip = f"{self.get_display_name(module)}" + if module != self.get_display_name(module): + full_tooltip += f"\n({module})" + self.create_tooltip(check, full_tooltip) + + # 使用颜色标签替代按钮 + color = self.formatter.module_colors.get(module, "black") + color_label = ttk.Label(frame, text="■", foreground=color, width=2, cursor="hand2") + color_label.pack(side=tk.LEFT, padx=2) + color_label.bind("", lambda e, m=module: self.choose_module_color(m)) + + current_col += 1 + + # 更新画布滚动区域 + self.module_inner_frame.update_idletasks() + self.module_canvas.config(scrollregion=self.module_canvas.bbox("all")) + + # 添加垂直滚动条 + if not hasattr(self, "module_scrollbar"): + self.module_scrollbar = ttk.Scrollbar( + self.module_frame, orient=tk.VERTICAL, command=self.module_canvas.yview + ) + self.module_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.module_canvas.config(yscrollcommand=self.module_scrollbar.set) + + def create_tooltip(self, widget, text): + """为控件创建工具提示""" + + def on_enter(event): + tooltip = tk.Toplevel() + tooltip.wm_overrideredirect(True) + tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}") + label = ttk.Label(tooltip, text=text, background="lightyellow", relief="solid", borderwidth=1) + label.pack() + widget.tooltip = tooltip + + def on_leave(event): + if hasattr(widget, "tooltip"): + widget.tooltip.destroy() + del widget.tooltip + + widget.bind("", on_enter) + widget.bind("", on_leave) + + def toggle_module(self, module, var): + """切换模块选择状态""" + if module == "全部": + if var.get(): + self.selected_modules = {"全部"} + else: + self.selected_modules.clear() + else: + if var.get(): + self.selected_modules.add(module) + if "全部" in self.selected_modules: + self.selected_modules.remove("全部") + else: + self.selected_modules.discard(module) + + self.filter_logs() + + def get_display_name(self, module_name): + """获取模块的显示名称""" + return self.module_name_mapping.get(module_name, module_name) + + def edit_module_mapping(self): + """编辑模块映射""" + mapping_window = tk.Toplevel(self.root) + mapping_window.title("编辑模块映射") + mapping_window.geometry("500x600") + + # 创建滚动框架 + frame = ttk.Frame(mapping_window) + frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) + + # 创建滚动条 + scrollbar = ttk.Scrollbar(frame) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + + # 创建映射编辑列表 + canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) + canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + scrollbar.config(command=canvas.yview) + + # 创建内部框架 + inner_frame = ttk.Frame(canvas) + canvas.create_window((0, 0), window=inner_frame, anchor="nw") + + # 添加标题 + ttk.Label(inner_frame, text="模块映射编辑", font=("", 12, "bold")).pack(anchor="w", padx=5, pady=5) + ttk.Label(inner_frame, text="英文名 -> 中文名", font=("", 10)).pack(anchor="w", padx=5, pady=2) + + # 映射编辑字典 + mapping_vars = {} + + # 添加现有模块的映射编辑 + all_modules = sorted(self.modules) + for module in all_modules: + frame_row = ttk.Frame(inner_frame) + frame_row.pack(fill=tk.X, padx=5, pady=2) + + ttk.Label(frame_row, text=module, width=20).pack(side=tk.LEFT, padx=5) + ttk.Label(frame_row, text="->").pack(side=tk.LEFT, padx=5) + + var = tk.StringVar(value=self.module_name_mapping.get(module, module)) + mapping_vars[module] = var + entry = ttk.Entry(frame_row, textvariable=var, width=25) + entry.pack(side=tk.LEFT, padx=5) + + # 更新画布滚动区域 + inner_frame.update_idletasks() + canvas.config(scrollregion=canvas.bbox("all")) + + def save_mappings(): + # 更新映射 + for module, var in mapping_vars.items(): + new_name = var.get().strip() + if new_name and new_name != module: + self.module_name_mapping[module] = new_name + elif module in self.module_name_mapping and not new_name: + del self.module_name_mapping[module] + + # 保存到文件 + self.save_module_mapping() + # 更新模块列表显示 + self.update_module_list() + mapping_window.destroy() + + # 添加按钮 + button_frame = ttk.Frame(mapping_window) + button_frame.pack(fill=tk.X, padx=5, pady=5) + ttk.Button(button_frame, text="保存", command=save_mappings).pack(side=tk.RIGHT, padx=5) + ttk.Button(button_frame, text="取消", command=mapping_window.destroy).pack(side=tk.RIGHT, padx=5) + + +def main(): + root = tk.Tk() + LogViewer(root) + root.mainloop() + + +if __name__ == "__main__": + main() + diff --git a/scripts/manifest_tool.py b/scripts/manifest_tool.py new file mode 100644 index 000000000..8312dc3e4 --- /dev/null +++ b/scripts/manifest_tool.py @@ -0,0 +1,237 @@ +""" +插件Manifest管理命令行工具 + +提供插件manifest文件的创建、验证和管理功能 +""" + +import os +import sys +import argparse +import json +from pathlib import Path +from src.common.logger import get_logger +from src.plugin_system.utils.manifest_utils import ( + ManifestValidator, +) + +# 添加项目根目录到Python路径 +project_root = Path(__file__).parent.parent.parent.parent +sys.path.insert(0, str(project_root)) + + +logger = get_logger("manifest_tool") + + +def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool: + """创建最小化的manifest文件 + + Args: + plugin_dir: 插件目录 + plugin_name: 插件名称 + description: 插件描述 + author: 插件作者 + + Returns: + bool: 是否创建成功 + """ + manifest_path = os.path.join(plugin_dir, "_manifest.json") + + if os.path.exists(manifest_path): + print(f"❌ Manifest文件已存在: {manifest_path}") + return False + + # 创建最小化manifest + minimal_manifest = { + "manifest_version": 1, + "name": plugin_name, + "version": "1.0.0", + "description": description or f"{plugin_name}插件", + "author": {"name": author or "Unknown"}, + } + + try: + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(minimal_manifest, f, ensure_ascii=False, indent=2) + print(f"✅ 已创建最小化manifest文件: {manifest_path}") + return True + except Exception as e: + print(f"❌ 创建manifest文件失败: {e}") + return False + + +def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool: + """创建完整的manifest模板文件 + + Args: + plugin_dir: 插件目录 + plugin_name: 插件名称 + + Returns: + bool: 是否创建成功 + """ + manifest_path = os.path.join(plugin_dir, "_manifest.json") + + if os.path.exists(manifest_path): + print(f"❌ Manifest文件已存在: {manifest_path}") + return False + + # 创建完整模板 + complete_manifest = { + "manifest_version": 1, + "name": plugin_name, + "version": "1.0.0", + "description": f"{plugin_name}插件描述", + "author": {"name": "插件作者", "url": "https://github.com/your-username"}, + "license": "MIT", + "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"}, + "homepage_url": "https://github.com/your-repo", + "repository_url": "https://github.com/your-repo", + "keywords": ["keyword1", "keyword2"], + "categories": ["Category1"], + "default_locale": "zh-CN", + "locales_path": "_locales", + "plugin_info": { + "is_built_in": False, + "plugin_type": "general", + "components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}], + }, + } + + try: + with open(manifest_path, "w", encoding="utf-8") as f: + json.dump(complete_manifest, f, ensure_ascii=False, indent=2) + print(f"✅ 已创建完整manifest模板: {manifest_path}") + print("💡 请根据实际情况修改manifest文件中的内容") + return True + except Exception as e: + print(f"❌ 创建manifest文件失败: {e}") + return False + + +def validate_manifest_file(plugin_dir: str) -> bool: + """验证manifest文件 + + Args: + plugin_dir: 插件目录 + + Returns: + bool: 是否验证通过 + """ + manifest_path = os.path.join(plugin_dir, "_manifest.json") + + if not os.path.exists(manifest_path): + print(f"❌ 未找到manifest文件: {manifest_path}") + return False + + try: + with open(manifest_path, "r", encoding="utf-8") as f: + manifest_data = json.load(f) + + validator = ManifestValidator() + is_valid = validator.validate_manifest(manifest_data) + + # 显示验证结果 + print("📋 Manifest验证结果:") + print(validator.get_validation_report()) + + if is_valid: + print("✅ Manifest文件验证通过") + else: + print("❌ Manifest文件验证失败") + + return is_valid + + except json.JSONDecodeError as e: + print(f"❌ Manifest文件格式错误: {e}") + return False + except Exception as e: + print(f"❌ 验证过程中发生错误: {e}") + return False + + +def scan_plugins_without_manifest(root_dir: str) -> None: + """扫描缺少manifest文件的插件 + + Args: + root_dir: 扫描的根目录 + """ + print(f"🔍 扫描目录: {root_dir}") + + plugins_without_manifest = [] + + for root, dirs, files in os.walk(root_dir): + # 跳过隐藏目录和__pycache__ + dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"] + + # 检查是否包含plugin.py文件(标识为插件目录) + if "plugin.py" in files: + manifest_path = os.path.join(root, "_manifest.json") + if not os.path.exists(manifest_path): + plugins_without_manifest.append(root) + + if plugins_without_manifest: + print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:") + for plugin_dir in plugins_without_manifest: + plugin_name = os.path.basename(plugin_dir) + print(f" - {plugin_name}: {plugin_dir}") + print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件") + else: + print("✅ 所有插件都有manifest文件") + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="插件Manifest管理工具") + subparsers = parser.add_subparsers(dest="command", help="可用命令") + + # 创建最小化manifest命令 + create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件") + create_minimal_parser.add_argument("plugin_dir", help="插件目录路径") + create_minimal_parser.add_argument("--name", help="插件名称") + create_minimal_parser.add_argument("--description", help="插件描述") + create_minimal_parser.add_argument("--author", help="插件作者") + + # 创建完整manifest命令 + create_complete_parser = subparsers.add_parser("create-complete", help="创建完整manifest模板") + create_complete_parser.add_argument("plugin_dir", help="插件目录路径") + create_complete_parser.add_argument("--name", help="插件名称") + + # 验证manifest命令 + validate_parser = subparsers.add_parser("validate", help="验证manifest文件") + validate_parser.add_argument("plugin_dir", help="插件目录路径") + + # 扫描插件命令 + scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件") + scan_parser.add_argument("root_dir", help="扫描的根目录路径") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + try: + if args.command == "create-minimal": + plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir)) + success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "") + sys.exit(0 if success else 1) + + elif args.command == "create-complete": + plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir)) + success = create_complete_manifest(args.plugin_dir, plugin_name) + sys.exit(0 if success else 1) + + elif args.command == "validate": + success = validate_manifest_file(args.plugin_dir) + sys.exit(0 if success else 1) + + elif args.command == "scan": + scan_plugins_without_manifest(args.root_dir) + + except Exception as e: + print(f"❌ 执行命令时发生错误: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py new file mode 100644 index 000000000..8b89a668a --- /dev/null +++ b/scripts/mongodb_to_sqlite.py @@ -0,0 +1,920 @@ +import os +import json +import sys # 新增系统模块导入 + +# import time +import pickle +from pathlib import Path + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from typing import Dict, Any, List, Optional, Type +from dataclasses import dataclass, field +from datetime import datetime +from pymongo import MongoClient +from pymongo.errors import ConnectionFailure +from peewee import Model, Field, IntegrityError + +# Rich 进度条和显示组件 +from rich.console import Console +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TaskProgressColumn, + TimeRemainingColumn, + TimeElapsedColumn, + SpinnerColumn, +) +from rich.table import Table +from rich.panel import Panel +# from rich.text import Text + +from src.common.database.database import db +from src.common.database.database_model import ( + ChatStreams, + Emoji, + Messages, + Images, + ImageDescriptions, + PersonInfo, + Knowledges, + ThinkingLog, + GraphNodes, + GraphEdges, +) +from src.common.logger import get_logger + +logger = get_logger("mongodb_to_sqlite") + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + + +@dataclass +class MigrationConfig: + """迁移配置类""" + + mongo_collection: str + target_model: Type[Model] + field_mapping: Dict[str, str] + batch_size: int = 500 + enable_validation: bool = True + skip_duplicates: bool = True + unique_fields: List[str] = field(default_factory=list) # 用于重复检查的字段 + + +# 数据验证相关类已移除 - 用户要求不要数据验证 + + +@dataclass +class MigrationCheckpoint: + """迁移断点数据""" + + collection_name: str + processed_count: int + last_processed_id: Any + timestamp: datetime + batch_errors: List[Dict[str, Any]] = field(default_factory=list) + + +@dataclass +class MigrationStats: + """迁移统计信息""" + + total_documents: int = 0 + processed_count: int = 0 + success_count: int = 0 + error_count: int = 0 + skipped_count: int = 0 + duplicate_count: int = 0 + validation_errors: int = 0 + batch_insert_count: int = 0 + errors: List[Dict[str, Any]] = field(default_factory=list) + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None + + def add_error(self, doc_id: Any, error: str, doc_data: Optional[Dict] = None): + """添加错误记录""" + self.errors.append( + {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data} + ) + self.error_count += 1 + + def add_validation_error(self, doc_id: Any, field: str, error: str): + """添加验证错误""" + self.add_error(doc_id, f"验证失败 - {field}: {error}") + self.validation_errors += 1 + + +class MongoToSQLiteMigrator: + """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" + + def __init__(self, mongo_uri: Optional[str] = None, database_name: Optional[str] = None): + self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") + self.mongo_uri = mongo_uri or self._build_mongo_uri() + self.mongo_client: Optional[MongoClient] = None + self.mongo_db = None + + # 迁移配置 + self.migration_configs = self._initialize_migration_configs() + + # 进度条控制台 + self.console = Console() + # 检查点目录 + self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints")) + self.checkpoint_dir.mkdir(exist_ok=True) + + # 验证规则已禁用 + self.validation_rules = self._initialize_validation_rules() + + def _build_mongo_uri(self) -> str: + """构建MongoDB连接URI""" + if mongo_uri := os.getenv("MONGODB_URI"): + return mongo_uri + + user = os.getenv("MONGODB_USER") + password = os.getenv("MONGODB_PASS") + host = os.getenv("MONGODB_HOST", "localhost") + port = os.getenv("MONGODB_PORT", "27017") + auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin") + + if user and password: + return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}" + else: + return f"mongodb://{host}:{port}/{self.database_name}" + + def _initialize_migration_configs(self) -> List[MigrationConfig]: + """初始化迁移配置""" + return [ # 表情包迁移配置 + MigrationConfig( + mongo_collection="emoji", + target_model=Emoji, + field_mapping={ + "full_path": "full_path", + "format": "format", + "hash": "emoji_hash", + "description": "description", + "emotion": "emotion", + "usage_count": "usage_count", + "last_used_time": "last_used_time", + # record_time字段将在转换时自动设置为当前时间 + }, + enable_validation=False, # 禁用数据验证 + unique_fields=["full_path", "emoji_hash"], + ), + # 聊天流迁移配置 + MigrationConfig( + mongo_collection="chat_streams", + target_model=ChatStreams, + field_mapping={ + "stream_id": "stream_id", + "create_time": "create_time", + "group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。 + "group_info.group_id": "group_id", # 同上 + "group_info.group_name": "group_name", # 同上 + "last_active_time": "last_active_time", + "platform": "platform", + "user_info.platform": "user_platform", + "user_info.user_id": "user_id", + "user_info.user_nickname": "user_nickname", + "user_info.user_cardname": "user_cardname", + }, + enable_validation=False, # 禁用数据验证 + unique_fields=["stream_id"], + ), + # 消息迁移配置 + MigrationConfig( + mongo_collection="messages", + target_model=Messages, + field_mapping={ + "message_id": "message_id", + "time": "time", + "chat_id": "chat_id", + "chat_info.stream_id": "chat_info_stream_id", + "chat_info.platform": "chat_info_platform", + "chat_info.user_info.platform": "chat_info_user_platform", + "chat_info.user_info.user_id": "chat_info_user_id", + "chat_info.user_info.user_nickname": "chat_info_user_nickname", + "chat_info.user_info.user_cardname": "chat_info_user_cardname", + "chat_info.group_info.platform": "chat_info_group_platform", + "chat_info.group_info.group_id": "chat_info_group_id", + "chat_info.group_info.group_name": "chat_info_group_name", + "chat_info.create_time": "chat_info_create_time", + "chat_info.last_active_time": "chat_info_last_active_time", + "user_info.platform": "user_platform", + "user_info.user_id": "user_id", + "user_info.user_nickname": "user_nickname", + "user_info.user_cardname": "user_cardname", + "processed_plain_text": "processed_plain_text", + "memorized_times": "memorized_times", + }, + enable_validation=False, # 禁用数据验证 + unique_fields=["message_id"], + ), + # 图片迁移配置 + MigrationConfig( + mongo_collection="images", + target_model=Images, + field_mapping={ + "hash": "emoji_hash", + "description": "description", + "path": "path", + "timestamp": "timestamp", + "type": "type", + }, + unique_fields=["path"], + ), + # 图片描述迁移配置 + MigrationConfig( + mongo_collection="image_descriptions", + target_model=ImageDescriptions, + field_mapping={ + "type": "type", + "hash": "image_description_hash", + "description": "description", + "timestamp": "timestamp", + }, + unique_fields=["image_description_hash", "type"], + ), + # 个人信息迁移配置 + MigrationConfig( + mongo_collection="person_info", + target_model=PersonInfo, + field_mapping={ + "person_id": "person_id", + "person_name": "person_name", + "name_reason": "name_reason", + "platform": "platform", + "user_id": "user_id", + "nickname": "nickname", + "relationship_value": "relationship_value", + "konw_time": "know_time", + }, + unique_fields=["person_id"], + ), + # 知识库迁移配置 + MigrationConfig( + mongo_collection="knowledges", + target_model=Knowledges, + field_mapping={"content": "content", "embedding": "embedding"}, + unique_fields=["content"], # 假设内容唯一 + ), + # 思考日志迁移配置 + MigrationConfig( + mongo_collection="thinking_log", + target_model=ThinkingLog, + field_mapping={ + "chat_id": "chat_id", + "trigger_text": "trigger_text", + "response_text": "response_text", + "trigger_info": "trigger_info_json", + "response_info": "response_info_json", + "timing_results": "timing_results_json", + "chat_history": "chat_history_json", + "chat_history_in_thinking": "chat_history_in_thinking_json", + "chat_history_after_response": "chat_history_after_response_json", + "heartflow_data": "heartflow_data_json", + "reasoning_data": "reasoning_data_json", + }, + unique_fields=["chat_id", "trigger_text"], + ), + # 图节点迁移配置 + MigrationConfig( + mongo_collection="graph_data.nodes", + target_model=GraphNodes, + field_mapping={ + "concept": "concept", + "memory_items": "memory_items", + "hash": "hash", + "created_time": "created_time", + "last_modified": "last_modified", + }, + unique_fields=["concept"], + ), + # 图边迁移配置 + MigrationConfig( + mongo_collection="graph_data.edges", + target_model=GraphEdges, + field_mapping={ + "source": "source", + "target": "target", + "strength": "strength", + "hash": "hash", + "created_time": "created_time", + "last_modified": "last_modified", + }, + unique_fields=["source", "target"], # 组合唯一性 + ), + ] + + def _initialize_validation_rules(self) -> Dict[str, Any]: + """数据验证已禁用 - 返回空字典""" + return {} + + def connect_mongodb(self) -> bool: + """连接到MongoDB""" + try: + self.mongo_client = MongoClient( + self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10 + ) + + # 测试连接 + self.mongo_client.admin.command("ping") + self.mongo_db = self.mongo_client[self.database_name] + + logger.info(f"成功连接到MongoDB: {self.database_name}") + return True + + except ConnectionFailure as e: + logger.error(f"MongoDB连接失败: {e}") + return False + except Exception as e: + logger.error(f"MongoDB连接异常: {e}") + return False + + def disconnect_mongodb(self): + """断开MongoDB连接""" + if self.mongo_client: + self.mongo_client.close() + logger.info("MongoDB连接已关闭") + + def _get_nested_value(self, document: Dict[str, Any], field_path: str) -> Any: + """获取嵌套字段的值""" + if "." not in field_path: + return document.get(field_path) + + parts = field_path.split(".") + value = document + + for part in parts: + if isinstance(value, dict): + value = value.get(part) + else: + return None + + if value is None: + break + + return value + + def _convert_field_value(self, value: Any, target_field: Field) -> Any: + """根据目标字段类型转换值""" + if value is None: + return None + + field_type = target_field.__class__.__name__ + + try: + if target_field.name == "record_time" and field_type == "DateTimeField": + return datetime.now() + + if field_type in ["CharField", "TextField"]: + if isinstance(value, (list, dict)): + return json.dumps(value, ensure_ascii=False) + return str(value) if value is not None else "" + + elif field_type == "IntegerField": + if isinstance(value, str): + # 处理字符串数字 + clean_value = value.strip() + if clean_value.replace(".", "").replace("-", "").isdigit(): + return int(float(clean_value)) + return 0 + return int(value) if value is not None else 0 + + elif field_type in ["FloatField", "DoubleField"]: + return float(value) if value is not None else 0.0 + + elif field_type == "BooleanField": + if isinstance(value, str): + return value.lower() in ("true", "1", "yes", "on") + return bool(value) + + elif field_type == "DateTimeField": + if isinstance(value, (int, float)): + return datetime.fromtimestamp(value) + elif isinstance(value, str): + try: + # 尝试解析ISO格式日期 + return datetime.fromisoformat(value.replace("Z", "+00:00")) + except ValueError: + try: + # 尝试解析时间戳字符串 + return datetime.fromtimestamp(float(value)) + except ValueError: + return datetime.now() + return datetime.now() + + return value + + except (ValueError, TypeError) as e: + logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}") + return self._get_default_value_for_field(target_field) + + def _get_default_value_for_field(self, field: Field) -> Any: + """获取字段的默认值""" + field_type = field.__class__.__name__ + + if hasattr(field, "default") and field.default is not None: + return field.default + + if field.null: + return None + + # 根据字段类型返回默认值 + if field_type in ["CharField", "TextField"]: + return "" + elif field_type == "IntegerField": + return 0 + elif field_type in ["FloatField", "DoubleField"]: + return 0.0 + elif field_type == "BooleanField": + return False + elif field_type == "DateTimeField": + return datetime.now() + + return None + + def _validate_data(self, collection_name: str, data: Dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: + """数据验证已禁用 - 始终返回True""" + return True + + def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any): + """保存迁移断点""" + checkpoint = MigrationCheckpoint( + collection_name=collection_name, + processed_count=processed_count, + last_processed_id=last_id, + timestamp=datetime.now(), + ) + + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + try: + with open(checkpoint_file, "wb") as f: + pickle.dump(checkpoint, f) + except Exception as e: + logger.warning(f"保存断点失败: {e}") + + def _load_checkpoint(self, collection_name: str) -> Optional[MigrationCheckpoint]: + """加载迁移断点""" + checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" + if not checkpoint_file.exists(): + return None + + try: + with open(checkpoint_file, "rb") as f: + return pickle.load(f) + except Exception as e: + logger.warning(f"加载断点失败: {e}") + return None + + def _batch_insert(self, model: Type[Model], data_list: List[Dict[str, Any]]) -> int: + """批量插入数据""" + if not data_list: + return 0 + + success_count = 0 + try: + with db.atomic(): + # 分批插入,避免SQL语句过长 + batch_size = 100 + for i in range(0, len(data_list), batch_size): + batch = data_list[i : i + batch_size] + model.insert_many(batch).execute() + success_count += len(batch) + except Exception as e: + logger.error(f"批量插入失败: {e}") + # 如果批量插入失败,尝试逐个插入 + for data in data_list: + try: + model.create(**data) + success_count += 1 + except Exception: + pass # 忽略单个插入失败 + + return success_count + + def _check_duplicate_by_unique_fields( + self, model: Type[Model], data: Dict[str, Any], unique_fields: List[str] + ) -> bool: + """根据唯一字段检查重复""" + if not unique_fields: + return False + + try: + query = model.select() + for field_name in unique_fields: + if field_name in data and data[field_name] is not None: + field_obj = getattr(model, field_name) + query = query.where(field_obj == data[field_name]) + + return query.exists() + except Exception as e: + logger.debug(f"重复检查失败: {e}") + return False + + def _create_model_instance(self, model: Type[Model], data: Dict[str, Any]) -> Optional[Model]: + """使用ORM创建模型实例""" + try: + # 过滤掉不存在的字段 + valid_data = {} + for field_name, value in data.items(): + if hasattr(model, field_name): + valid_data[field_name] = value + else: + logger.debug(f"跳过未知字段: {field_name}") + + # 创建实例 + instance = model.create(**valid_data) + return instance + + except IntegrityError as e: + # 处理唯一约束冲突等完整性错误 + logger.debug(f"完整性约束冲突: {e}") + return None + except Exception as e: + logger.error(f"创建模型实例失败: {e}") + return None + + def migrate_collection(self, config: MigrationConfig) -> MigrationStats: + """迁移单个集合 - 使用优化的批量插入和进度条""" + stats = MigrationStats() + stats.start_time = datetime.now() + + # 检查是否有断点 + checkpoint = self._load_checkpoint(config.mongo_collection) + start_from_id = checkpoint.last_processed_id if checkpoint else None + if checkpoint: + stats.processed_count = checkpoint.processed_count + logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录") + + logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") + + try: + # 获取MongoDB集合 + mongo_collection = self.mongo_db[config.mongo_collection] + + # 构建查询条件(用于断点恢复) + query = {} + if start_from_id: + query = {"_id": {"$gt": start_from_id}} + + stats.total_documents = mongo_collection.count_documents(query) + + if stats.total_documents == 0: + logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") + return stats + + logger.info(f"待迁移文档数量: {stats.total_documents}") + + # 创建Rich进度条 + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=self.console, + refresh_per_second=10, + ) as progress: + task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents) + # 批量处理数据 + batch_data = [] + batch_count = 0 + last_processed_id = None + + for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size): + try: + doc_id = mongo_doc.get("_id", "unknown") + last_processed_id = doc_id + + # 构建目标数据 + target_data = {} + for mongo_field, sqlite_field in config.field_mapping.items(): + value = self._get_nested_value(mongo_doc, mongo_field) + + # 获取目标字段对象并转换类型 + if hasattr(config.target_model, sqlite_field): + field_obj = getattr(config.target_model, sqlite_field) + converted_value = self._convert_field_value(value, field_obj) + target_data[sqlite_field] = converted_value + + # 数据验证已禁用 + # if config.enable_validation: + # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats): + # stats.skipped_count += 1 + # continue + + # 重复检查 + if config.skip_duplicates and self._check_duplicate_by_unique_fields( + config.target_model, target_data, config.unique_fields + ): + stats.duplicate_count += 1 + stats.skipped_count += 1 + logger.debug(f"跳过重复记录: {doc_id}") + continue + + # 添加到批量数据 + batch_data.append(target_data) + stats.processed_count += 1 + + # 执行批量插入 + if len(batch_data) >= config.batch_size: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + + # 保存断点 + self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id) + + batch_data.clear() + batch_count += 1 + + # 更新进度条 + progress.update(task, advance=config.batch_size) + + except Exception as e: + doc_id = mongo_doc.get("_id", "unknown") + stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc) + logger.error(f"处理文档失败 (ID: {doc_id}): {e}") + + # 处理剩余的批量数据 + if batch_data: + success_count = self._batch_insert(config.target_model, batch_data) + stats.success_count += success_count + stats.batch_insert_count += 1 + progress.update(task, advance=len(batch_data)) + + # 完成进度条 + progress.update(task, completed=stats.total_documents) + + stats.end_time = datetime.now() + duration = stats.end_time - stats.start_time + + logger.info( + f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" + f"总计: {stats.total_documents}, 成功: {stats.success_count}, " + f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n" + f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}" + ) + + # 清理断点文件 + checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl" + if checkpoint_file.exists(): + checkpoint_file.unlink() + + except Exception as e: + logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") + stats.add_error("collection_error", str(e)) + + return stats + + def migrate_all(self) -> Dict[str, MigrationStats]: + """执行所有迁移任务""" + logger.info("开始执行数据库迁移...") + + if not self.connect_mongodb(): + logger.error("无法连接到MongoDB,迁移终止") + return {} + + all_stats = {} + + try: + # 创建总体进度表格 + total_collections = len(self.migration_configs) + self.console.print( + Panel( + f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" + f"[yellow]总集合数: {total_collections}[/yellow]", + title="迁移开始", + expand=False, + ) + ) + for idx, config in enumerate(self.migration_configs, 1): + self.console.print( + f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]" + ) + stats = self.migrate_collection(config) + all_stats[config.mongo_collection] = stats + + # 显示单个集合的快速统计 + if stats.processed_count > 0: + success_rate = stats.success_count / stats.processed_count * 100 + if success_rate >= 95: + status_emoji = "✅" + status_color = "bright_green" + elif success_rate >= 80: + status_emoji = "⚠️" + status_color = "yellow" + else: + status_emoji = "❌" + status_color = "red" + + self.console.print( + f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} " + f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]" + ) + + # 错误率检查 + if stats.processed_count > 0: + error_rate = stats.error_count / stats.processed_count + if error_rate > 0.1: # 错误率超过10% + self.console.print( + f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} " + f"({stats.error_count}/{stats.processed_count})[/red]" + ) + + finally: + self.disconnect_mongodb() + + self._print_migration_summary(all_stats) + return all_stats + + def _print_migration_summary(self, all_stats: Dict[str, MigrationStats]): + """使用Rich打印美观的迁移汇总信息""" + # 计算总体统计 + total_processed = sum(stats.processed_count for stats in all_stats.values()) + total_success = sum(stats.success_count for stats in all_stats.values()) + total_errors = sum(stats.error_count for stats in all_stats.values()) + total_skipped = sum(stats.skipped_count for stats in all_stats.values()) + total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) + total_validation_errors = sum(stats.validation_errors for stats in all_stats.values()) + total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values()) + + # 计算总耗时 + total_duration_seconds = 0 + for stats in all_stats.values(): + if stats.start_time and stats.end_time: + duration = stats.end_time - stats.start_time + total_duration_seconds += duration.total_seconds() + + # 创建详细统计表格 + table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta") + table.add_column("集合名称", style="cyan", width=20) + table.add_column("文档总数", justify="right", style="blue") + table.add_column("处理数量", justify="right", style="green") + table.add_column("成功数量", justify="right", style="green") + table.add_column("错误数量", justify="right", style="red") + table.add_column("跳过数量", justify="right", style="yellow") + table.add_column("重复数量", justify="right", style="bright_yellow") + table.add_column("验证错误", justify="right", style="red") + table.add_column("批次数", justify="right", style="purple") + table.add_column("成功率", justify="right", style="bright_green") + table.add_column("耗时(秒)", justify="right", style="blue") + + for collection_name, stats in all_stats.items(): + success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 + duration = 0 + if stats.start_time and stats.end_time: + duration = (stats.end_time - stats.start_time).total_seconds() + + # 根据成功率设置颜色 + if success_rate >= 95: + success_rate_style = "[bright_green]" + elif success_rate >= 80: + success_rate_style = "[yellow]" + else: + success_rate_style = "[red]" + + table.add_row( + collection_name, + str(stats.total_documents), + str(stats.processed_count), + str(stats.success_count), + f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0", + f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0", + f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0", + f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0", + str(stats.batch_insert_count), + f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}", + f"{duration:.2f}", + ) + + # 添加总计行 + total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 + if total_success_rate >= 95: + total_rate_style = "[bright_green]" + elif total_success_rate >= 80: + total_rate_style = "[yellow]" + else: + total_rate_style = "[red]" + + table.add_section() + table.add_row( + "[bold]总计[/bold]", + f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]", + f"[bold]{total_processed}[/bold]", + f"[bold]{total_success}[/bold]", + f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]", + f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]", + f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" + if total_duplicates > 0 + else "[bold]0[/bold]", + f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]", + f"[bold]{total_batch_inserts}[/bold]", + f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]", + f"[bold]{total_duration_seconds:.2f}[/bold]", + ) + + self.console.print(table) + + # 创建状态面板 + status_items = [] + if total_errors > 0: + status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]") + + if total_validation_errors > 0: + status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]") + + if total_duplicates > 0: + status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]") + + if total_success_rate >= 95: + status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]") + elif total_success_rate >= 80: + status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]") + else: + status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]") + + if status_items: + status_panel = Panel( + "\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow" + ) + self.console.print(status_panel) + + # 性能统计面板 + avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0 + performance_info = ( + f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n" + f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n" + f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作" + ) + + performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green") + self.console.print(performance_panel) + + def add_migration_config(self, config: MigrationConfig): + """添加新的迁移配置""" + self.migration_configs.append(config) + + def migrate_single_collection(self, collection_name: str) -> Optional[MigrationStats]: + """迁移单个指定的集合""" + config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) + if not config: + logger.error(f"未找到集合 {collection_name} 的迁移配置") + return None + + if not self.connect_mongodb(): + logger.error("无法连接到MongoDB") + return None + + try: + stats = self.migrate_collection(config) + self._print_migration_summary({collection_name: stats}) + return stats + finally: + self.disconnect_mongodb() + + def export_error_report(self, all_stats: Dict[str, MigrationStats], filepath: str): + """导出错误报告""" + error_report = { + "timestamp": datetime.now().isoformat(), + "summary": { + collection: { + "total": stats.total_documents, + "processed": stats.processed_count, + "success": stats.success_count, + "errors": stats.error_count, + "skipped": stats.skipped_count, + "duplicates": stats.duplicate_count, + } + for collection, stats in all_stats.items() + }, + "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors}, + } + + try: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(error_report, f, ensure_ascii=False, indent=2) + logger.info(f"错误报告已导出到: {filepath}") + except Exception as e: + logger.error(f"导出错误报告失败: {e}") + + +def main(): + """主程序入口""" + migrator = MongoToSQLiteMigrator() + + # 执行迁移 + migration_results = migrator.migrate_all() + + # 导出错误报告(如果有错误) + if any(stats.error_count > 0 for stats in migration_results.values()): + error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + migrator.export_error_report(migration_results, error_report_path) + + logger.info("数据迁移完成!") + + +if __name__ == "__main__": + main() diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py new file mode 100644 index 000000000..42a99133f --- /dev/null +++ b/scripts/raw_data_preprocessor.py @@ -0,0 +1,75 @@ +import os +from pathlib import Path +import sys # 新增系统模块导入 +from src.chat.knowledge.utils.hash import get_sha256 +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from src.common.logger import get_logger + +logger = get_logger("lpmm") +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") +# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") + +def _process_text_file(file_path): + """处理单个文本文件,返回段落列表""" + with open(file_path, "r", encoding="utf-8") as f: + raw = f.read() + + paragraphs = [] + paragraph = "" + for line in raw.split("\n"): + if line.strip() == "": + if paragraph != "": + paragraphs.append(paragraph.strip()) + paragraph = "" + else: + paragraph += line + "\n" + + if paragraph != "": + paragraphs.append(paragraph.strip()) + + return paragraphs + + +def _process_multi_files() -> list: + raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) + if not raw_files: + logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") + sys.exit(1) + # 处理所有文件 + all_paragraphs = [] + for file in raw_files: + logger.info(f"正在处理文件: {file.name}") + paragraphs = _process_text_file(file) + all_paragraphs.extend(paragraphs) + return all_paragraphs + +def load_raw_data() -> tuple[list[str], list[str]]: + """加载原始数据文件 + + 读取原始数据文件,将原始数据加载到内存中 + + Args: + path: 可选,指定要读取的json文件绝对路径 + + Returns: + - raw_data: 原始数据列表 + - sha256_list: 原始数据的SHA256集合 + """ + raw_data = _process_multi_files() + sha256_list = [] + sha256_set = set() + for item in raw_data: + if not isinstance(item, str): + logger.warning(f"数据类型错误:{item}") + continue + pg_hash = get_sha256(item) + if pg_hash in sha256_set: + logger.warning(f"重复数据:{item}") + continue + sha256_set.add(pg_hash) + sha256_list.append(pg_hash) + raw_data.append(item) + logger.info(f"共读取到{len(raw_data)}条数据") + + return sha256_list, raw_data \ No newline at end of file diff --git a/scripts/run.sh b/scripts/run.sh new file mode 100644 index 000000000..d702323a6 --- /dev/null +++ b/scripts/run.sh @@ -0,0 +1,556 @@ +#!/bin/bash + +# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987 +# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9 +# 请小心使用任何一键脚本! + +INSTALLER_VERSION="0.0.5-refactor" +LANG=C.UTF-8 + +# 如无法访问GitHub请修改此处镜像地址 +GITHUB_REPO="https://ghfast.top/https://github.com" + +# 颜色输出 +GREEN="\e[32m" +RED="\e[31m" +RESET="\e[0m" + +# 需要的基本软件包 + +declare -A REQUIRED_PACKAGES=( + ["common"]="git sudo python3 curl gnupg" + ["debian"]="python3-venv python3-pip build-essential" + ["ubuntu"]="python3-venv python3-pip build-essential" + ["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make" + ["arch"]="python-virtualenv python-pip base-devel" +) + +# 默认项目目录 +DEFAULT_INSTALL_DIR="/opt/maicore" + +# 服务名称 +SERVICE_NAME="maicore" +SERVICE_NAME_WEB="maicore-web" +SERVICE_NAME_NBADAPTER="maibot-napcat-adapter" + +IS_INSTALL_NAPCAT=false +IS_INSTALL_DEPENDENCIES=false + +# 检查是否已安装 +check_installed() { + [[ -f /etc/systemd/system/${SERVICE_NAME}.service ]] +} + +# 加载安装信息 +load_install_info() { + if [[ -f /etc/maicore_install.conf ]]; then + source /etc/maicore_install.conf + else + INSTALL_DIR="$DEFAULT_INSTALL_DIR" + BRANCH="refactor" + fi +} + +# 显示管理菜单 +show_menu() { + while true; do + choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \ + "1" "启动MaiCore" \ + "2" "停止MaiCore" \ + "3" "重启MaiCore" \ + "4" "启动NapCat Adapter" \ + "5" "停止NapCat Adapter" \ + "6" "重启NapCat Adapter" \ + "7" "拉取最新MaiCore仓库" \ + "8" "切换分支" \ + "9" "退出" 3>&1 1>&2 2>&3) + + [[ $? -ne 0 ]] && exit 0 + + case "$choice" in + 1) + systemctl start ${SERVICE_NAME} + whiptail --msgbox "✅MaiCore已启动" 10 60 + ;; + 2) + systemctl stop ${SERVICE_NAME} + whiptail --msgbox "🛑MaiCore已停止" 10 60 + ;; + 3) + systemctl restart ${SERVICE_NAME} + whiptail --msgbox "🔄MaiCore已重启" 10 60 + ;; + 4) + systemctl start ${SERVICE_NAME_NBADAPTER} + whiptail --msgbox "✅NapCat Adapter已启动" 10 60 + ;; + 5) + systemctl stop ${SERVICE_NAME_NBADAPTER} + whiptail --msgbox "🛑NapCat Adapter已停止" 10 60 + ;; + 6) + systemctl restart ${SERVICE_NAME_NBADAPTER} + whiptail --msgbox "🔄NapCat Adapter已重启" 10 60 + ;; + 7) + update_dependencies + ;; + 8) + switch_branch + ;; + 9) + exit 0 + ;; + *) + whiptail --msgbox "无效选项!" 10 60 + ;; + esac + done +} + +# 更新依赖 +update_dependencies() { + whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60 + systemctl stop ${SERVICE_NAME} + cd "${INSTALL_DIR}/MaiBot" || { + whiptail --msgbox "🚫 无法进入安装目录!" 10 60 + return 1 + } + if ! git pull origin "${BRANCH}"; then + whiptail --msgbox "🚫 代码更新失败!" 10 60 + return 1 + fi + source "${INSTALL_DIR}/venv/bin/activate" + if ! pip install -r requirements.txt; then + whiptail --msgbox "🚫 依赖安装失败!" 10 60 + deactivate + return 1 + fi + deactivate + whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60 +} + +# 切换分支 +switch_branch() { + new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3) + [[ -z "$new_branch" ]] && { + whiptail --msgbox "🚫 分支名称不能为空!" 10 60 + return 1 + } + + cd "${INSTALL_DIR}/MaiBot" || { + whiptail --msgbox "🚫 无法进入安装目录!" 10 60 + return 1 + } + + if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then + whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60 + return 1 + fi + + if ! git checkout "${new_branch}"; then + whiptail --msgbox "🚫 分支切换失败!" 10 60 + return 1 + fi + + if ! git pull origin "${new_branch}"; then + whiptail --msgbox "🚫 代码拉取失败!" 10 60 + return 1 + fi + systemctl stop ${SERVICE_NAME} + source "${INSTALL_DIR}/venv/bin/activate" + pip install -r requirements.txt + deactivate + + sed -i "s/^BRANCH=.*/BRANCH=${new_branch}/" /etc/maicore_install.conf + BRANCH="${new_branch}" + check_eula + whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} !" 10 60 +} + +check_eula() { + # 首先计算当前EULA的MD5值 + current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}') + + # 首先计算当前隐私条款文件的哈希值 + current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}') + + # 如果当前的md5值为空,则直接返回 + if [[ -z $current_md5 || -z $current_md5_privacy ]]; then + whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60 + fi + + # 检查eula.confirmed文件是否存在 + if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then + # 如果存在则检查其中包含的md5与current_md5是否一致 + confirmed_md5=$(cat ${INSTALL_DIR}/MaiBot/eula.confirmed) + else + confirmed_md5="" + fi + + # 检查privacy.confirmed文件是否存在 + if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then + # 如果存在则检查其中包含的md5与current_md5是否一致 + confirmed_md5_privacy=$(cat ${INSTALL_DIR}/MaiBot/privacy.confirmed) + else + confirmed_md5_privacy="" + fi + + # 如果EULA或隐私条款有更新,提示用户重新确认 + if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then + whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议? \n\n " 12 70 + if [[ $? -eq 0 ]]; then + echo -n $current_md5 > ${INSTALL_DIR}/MaiBot/eula.confirmed + echo -n $current_md5_privacy > ${INSTALL_DIR}/MaiBot/privacy.confirmed + else + exit 1 + fi + fi + +} + +# ----------- 主安装流程 ----------- +run_installation() { + # 1/6: 检测是否安装 whiptail + if ! command -v whiptail &>/dev/null; then + echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}" + + if command -v apt-get &>/dev/null; then + apt-get update && apt-get install -y whiptail + elif command -v pacman &>/dev/null; then + pacman -Syu --noconfirm whiptail + elif command -v yum &>/dev/null; then + yum install -y whiptail + else + echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}" + exit 1 + fi + fi + + whiptail --title "ℹ️ 提示" --msgbox "如果您没有特殊需求,请优先使用docker方式部署。" 10 60 + + # 协议确认 + if ! (whiptail --title "ℹ️ [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议?" 12 70); then + exit 1 + fi + + # 欢迎信息 + whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore,将自动进入安装流程,安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段,代码可能随时更改\n文档未完善,有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险,请自行了解,谨慎使用\n由于持续迭代,可能存在一些已知或未知的bug\n由于开发中,可能消耗较多token\n\n本脚本可能更新不及时,如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60 + + # 系统检查 + check_system() { + if [[ "$(id -u)" -ne 0 ]]; then + whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60 + exit 1 + fi + + if [[ -f /etc/os-release ]]; then + source /etc/os-release + if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then + return + elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then + return + elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then + return + elif [[ "$ID" == "arch" ]]; then + whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60 + return + else + whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60 + exit 1 + fi + else + whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60 + exit 1 + fi + } + check_system + + # 设置包管理器 + case "$ID" in + debian|ubuntu) + PKG_MANAGER="apt" + ;; + centos) + PKG_MANAGER="yum" + ;; + arch) + # 添加arch包管理器 + PKG_MANAGER="pacman" + ;; + esac + + # 检查NapCat + check_napcat() { + if command -v napcat &>/dev/null; then + NAPCAT_INSTALLED=true + else + NAPCAT_INSTALLED=false + fi + } + check_napcat + + # 安装必要软件包 + install_packages() { + missing_packages=() + # 检查 common 及当前系统专属依赖 + for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do + case "$PKG_MANAGER" in + apt) + dpkg -s "$package" &>/dev/null || missing_packages+=("$package") + ;; + yum) + rpm -q "$package" &>/dev/null || missing_packages+=("$package") + ;; + pacman) + pacman -Qi "$package" &>/dev/null || missing_packages+=("$package") + ;; + esac + done + + if [[ ${#missing_packages[@]} -gt 0 ]]; then + whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装?" 10 60 + if [[ $? -eq 0 ]]; then + IS_INSTALL_DEPENDENCIES=true + else + whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续?" 10 60 || exit 1 + fi + fi + } + install_packages + + # 安装NapCat + install_napcat() { + [[ $NAPCAT_INSTALLED == true ]] && return + whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat,是否安装?\n如果您想使用远程NapCat,请跳过此步。" 10 60 && { + IS_INSTALL_NAPCAT=true + } + } + + # 仅在非Arch系统上安装NapCat + [[ "$ID" != "arch" ]] && install_napcat + + # Python版本检查 + check_python() { + PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') + if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then + whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60 + exit 1 + fi + } + + # 如果没安装python则不检查python版本 + if command -v python3 &>/dev/null; then + check_python + fi + + + # 选择分支 + choose_branch() { + BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \ + "main" "稳定版本(推荐)" ON \ + "dev" "开发版(不知道什么意思就别选)" OFF \ + "classical" "经典版(0.6.0以前的版本)" OFF \ + "custom" "自定义分支" OFF 3>&1 1>&2 2>&3) + RETVAL=$? + if [ $RETVAL -ne 0 ]; then + whiptail --msgbox "🚫 操作取消!" 10 60 + exit 1 + fi + + if [[ "$BRANCH" == "custom" ]]; then + BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3) + RETVAL=$? + if [ $RETVAL -ne 0 ]; then + whiptail --msgbox "🚫 输入取消!" 10 60 + exit 1 + fi + if [[ -z "$BRANCH" ]]; then + whiptail --msgbox "🚫 分支名称不能为空!" 10 60 + exit 1 + fi + fi + } + choose_branch + + # 选择安装路径 + choose_install_dir() { + INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录:" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3) + [[ -z "$INSTALL_DIR" ]] && { + whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1 + INSTALL_DIR="$DEFAULT_INSTALL_DIR" + } + } + choose_install_dir + + # 确认安装 + confirm_install() { + local confirm_msg="请确认以下更改:\n\n" + confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n" + confirm_msg+="🔀 分支: $BRANCH\n" + [[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n" + [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n" + + [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n" + confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。" + + whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1 + } + confirm_install + + # 开始安装 + echo -e "${GREEN}安装${missing_packages[@]}...${RESET}" + + if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then + case "$PKG_MANAGER" in + apt) + apt update && apt install -y "${missing_packages[@]}" + ;; + yum) + yum install -y "${missing_packages[@]}" --nobest + ;; + pacman) + pacman -S --noconfirm "${missing_packages[@]}" + ;; + esac + fi + + if [[ $IS_INSTALL_NAPCAT == true ]]; then + echo -e "${GREEN}安装 NapCat...${RESET}" + curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n + fi + + echo -e "${GREEN}创建安装目录...${RESET}" + mkdir -p "$INSTALL_DIR" + cd "$INSTALL_DIR" || exit 1 + + echo -e "${GREEN}设置Python虚拟环境...${RESET}" + python3 -m venv venv + source venv/bin/activate + + echo -e "${GREEN}克隆MaiCore仓库...${RESET}" + git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || { + echo -e "${RED}克隆MaiCore仓库失败!${RESET}" + exit 1 + } + + echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}" + git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || { + echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}" + exit 1 + } + + echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}" + git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || { + echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}" + exit 1 + } + + + echo -e "${GREEN}安装Python依赖...${RESET}" + pip install -r MaiBot/requirements.txt + cd MaiBot + pip install uv + uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt + cd .. + + echo -e "${GREEN}安装maim_message依赖...${RESET}" + cd maim_message + uv pip install -i https://mirrors.aliyun.com/pypi/simple -e . + cd .. + + echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}" + cd MaiBot-Napcat-Adapter + uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt + cd .. + + echo -e "${GREEN}同意协议...${RESET}" + + # 首先计算当前EULA的MD5值 + current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}') + + # 首先计算当前隐私条款文件的哈希值 + current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}') + + echo -n $current_md5 > MaiBot/eula.confirmed + echo -n $current_md5_privacy > MaiBot/privacy.confirmed + + echo -e "${GREEN}创建系统服务...${RESET}" + cat > /etc/systemd/system/${SERVICE_NAME}.service < /etc/systemd/system/${SERVICE_NAME_WEB}.service < /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service < /etc/maicore_install.conf + echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf + echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf + + whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建系统服务:${SERVICE_NAME}、${SERVICE_NAME_WEB}、${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务:\n启动服务:systemctl start ${SERVICE_NAME}\n查看状态:systemctl status ${SERVICE_NAME}" 14 60 +} + +# ----------- 主执行流程 ----------- +# 检查root权限 +[[ $(id -u) -ne 0 ]] && { + echo -e "${RED}请使用root用户运行此脚本!${RESET}" + exit 1 +} + +# 如果已安装显示菜单,并检查协议是否更新 +if check_installed; then + load_install_info + check_eula + show_menu +else + run_installation + # 安装完成后询问是否启动 + if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务?" 10 60; then + systemctl start ${SERVICE_NAME} + whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60 + fi +fi diff --git a/scripts/run_lpmm.sh b/scripts/run_lpmm.sh new file mode 100644 index 000000000..f3f54610d --- /dev/null +++ b/scripts/run_lpmm.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# ============================================== +# Environment Initialization +# ============================================== + +# Step 1: Locate project root directory +SCRIPTS_DIR="scripts" +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +PROJECT_ROOT=$(cd "$SCRIPT_DIR/.." && pwd) + +# Step 2: Verify scripts directory exists +if [ ! -d "$PROJECT_ROOT/$SCRIPTS_DIR" ]; then + echo "❌ Error: scripts directory not found in project root" >&2 + echo "Current path: $PROJECT_ROOT" >&2 + exit 1 +fi + +# Step 3: Set up Python environment +export PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" +cd "$PROJECT_ROOT" || { + echo "❌ Failed to cd to project root: $PROJECT_ROOT" >&2 + exit 1 +} + +# Debug info +echo "============================" +echo "Project Root: $PROJECT_ROOT" +echo "Python Path: $PYTHONPATH" +echo "Working Dir: $(pwd)" +echo "============================" + +# ============================================== +# Python Script Execution +# ============================================== + +run_python_script() { + local script_name=$1 + echo "🔄 Running $script_name" + if ! python3 "$SCRIPTS_DIR/$script_name"; then + echo "❌ $script_name failed" >&2 + exit 1 + fi +} + +# Execute scripts in order +run_python_script "raw_data_preprocessor.py" +run_python_script "info_extraction.py" +run_python_script "import_openie.py" + +echo "✅ All scripts completed successfully" \ No newline at end of file diff --git a/scripts/text_length_analysis.py b/scripts/text_length_analysis.py new file mode 100644 index 000000000..2ca596e2f --- /dev/null +++ b/scripts/text_length_analysis.py @@ -0,0 +1,394 @@ +import time +import sys +import os +import re +from typing import Dict, List, Tuple, Optional +from datetime import datetime +# Add project root to Python path +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) +from src.common.database.database_model import Messages, ChatStreams #noqa + + +def contains_emoji_or_image_tags(text: str) -> bool: + """Check if text contains [表情包xxxxx] or [图片xxxxx] tags""" + if not text: + return False + + # 检查是否包含 [表情包] 或 [图片] 标记 + emoji_pattern = r'\[表情包[^\]]*\]' + image_pattern = r'\[图片[^\]]*\]' + + return bool(re.search(emoji_pattern, text) or re.search(image_pattern, text)) + + +def clean_reply_text(text: str) -> str: + """Remove reply references like [回复 xxxx...] from text""" + if not text: + return text + + # 匹配 [回复 xxxx...] 格式的内容 + # 使用非贪婪匹配,匹配到第一个 ] 就停止 + cleaned_text = re.sub(r'\[回复[^\]]*\]', '', text) + + # 去除多余的空白字符 + cleaned_text = cleaned_text.strip() + + return cleaned_text + + +def get_chat_name(chat_id: str) -> str: + """Get chat name from chat_id by querying ChatStreams table directly""" + try: + chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) + if chat_stream is None: + return f"未知聊天 ({chat_id})" + + if chat_stream.group_name: + return f"{chat_stream.group_name} ({chat_id})" + elif chat_stream.user_nickname: + return f"{chat_stream.user_nickname}的私聊 ({chat_id})" + else: + return f"未知聊天 ({chat_id})" + except Exception: + return f"查询失败 ({chat_id})" + + +def format_timestamp(timestamp: float) -> str: + """Format timestamp to readable date string""" + try: + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return "未知时间" + + +def calculate_text_length_distribution(messages) -> Dict[str, int]: + """Calculate distribution of processed_plain_text length""" + distribution = { + '0': 0, # 空文本 + '1-5': 0, # 极短文本 + '6-10': 0, # 很短文本 + '11-20': 0, # 短文本 + '21-30': 0, # 较短文本 + '31-50': 0, # 中短文本 + '51-70': 0, # 中等文本 + '71-100': 0, # 较长文本 + '101-150': 0, # 长文本 + '151-200': 0, # 很长文本 + '201-300': 0, # 超长文本 + '301-500': 0, # 极长文本 + '501-1000': 0, # 巨长文本 + '1000+': 0 # 超巨长文本 + } + + for msg in messages: + if msg.processed_plain_text is None: + continue + + # 排除包含表情包或图片标记的消息 + if contains_emoji_or_image_tags(msg.processed_plain_text): + continue + + # 清理文本中的回复引用 + cleaned_text = clean_reply_text(msg.processed_plain_text) + length = len(cleaned_text) + + if length == 0: + distribution['0'] += 1 + elif length <= 5: + distribution['1-5'] += 1 + elif length <= 10: + distribution['6-10'] += 1 + elif length <= 20: + distribution['11-20'] += 1 + elif length <= 30: + distribution['21-30'] += 1 + elif length <= 50: + distribution['31-50'] += 1 + elif length <= 70: + distribution['51-70'] += 1 + elif length <= 100: + distribution['71-100'] += 1 + elif length <= 150: + distribution['101-150'] += 1 + elif length <= 200: + distribution['151-200'] += 1 + elif length <= 300: + distribution['201-300'] += 1 + elif length <= 500: + distribution['301-500'] += 1 + elif length <= 1000: + distribution['501-1000'] += 1 + else: + distribution['1000+'] += 1 + + return distribution + + +def get_text_length_stats(messages) -> Dict[str, float]: + """Calculate basic statistics for processed_plain_text length""" + lengths = [] + null_count = 0 + excluded_count = 0 # 被排除的消息数量 + + for msg in messages: + if msg.processed_plain_text is None: + null_count += 1 + elif contains_emoji_or_image_tags(msg.processed_plain_text): + # 排除包含表情包或图片标记的消息 + excluded_count += 1 + else: + # 清理文本中的回复引用 + cleaned_text = clean_reply_text(msg.processed_plain_text) + lengths.append(len(cleaned_text)) + + if not lengths: + return { + 'count': 0, + 'null_count': null_count, + 'excluded_count': excluded_count, + 'min': 0, + 'max': 0, + 'avg': 0, + 'median': 0 + } + + lengths.sort() + count = len(lengths) + + return { + 'count': count, + 'null_count': null_count, + 'excluded_count': excluded_count, + 'min': min(lengths), + 'max': max(lengths), + 'avg': sum(lengths) / count, + 'median': lengths[count // 2] if count % 2 == 1 else (lengths[count // 2 - 1] + lengths[count // 2]) / 2 + } + + +def get_available_chats() -> List[Tuple[str, str, int]]: + """Get all available chats with message counts""" + try: + # 获取所有有消息的chat_id,排除特殊类型消息 + chat_counts = {} + for msg in Messages.select(Messages.chat_id).distinct(): + chat_id = msg.chat_id + count = Messages.select().where( + (Messages.chat_id == chat_id) & + (Messages.is_emoji != 1) & + (Messages.is_picid != 1) & + (Messages.is_command != 1) + ).count() + if count > 0: + chat_counts[chat_id] = count + + # 获取聊天名称 + result = [] + for chat_id, count in chat_counts.items(): + chat_name = get_chat_name(chat_id) + result.append((chat_id, chat_name, count)) + + # 按消息数量排序 + result.sort(key=lambda x: x[2], reverse=True) + return result + except Exception as e: + print(f"获取聊天列表失败: {e}") + return [] + + +def get_time_range_input() -> Tuple[Optional[float], Optional[float]]: + """Get time range input from user""" + print("\n时间范围选择:") + print("1. 最近1天") + print("2. 最近3天") + print("3. 最近7天") + print("4. 最近30天") + print("5. 自定义时间范围") + print("6. 不限制时间") + + choice = input("请选择时间范围 (1-6): ").strip() + + now = time.time() + + if choice == "1": + return now - 24*3600, now + elif choice == "2": + return now - 3*24*3600, now + elif choice == "3": + return now - 7*24*3600, now + elif choice == "4": + return now - 30*24*3600, now + elif choice == "5": + print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") + start_str = input().strip() + print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") + end_str = input().strip() + + try: + start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() + end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() + return start_time, end_time + except ValueError: + print("时间格式错误,将不限制时间范围") + return None, None + else: + return None, None + + +def get_top_longest_messages(messages, top_n: int = 10) -> List[Tuple[str, int, str, str]]: + """Get top N longest messages""" + message_lengths = [] + + for msg in messages: + if msg.processed_plain_text is not None: + # 排除包含表情包或图片标记的消息 + if contains_emoji_or_image_tags(msg.processed_plain_text): + continue + + # 清理文本中的回复引用 + cleaned_text = clean_reply_text(msg.processed_plain_text) + length = len(cleaned_text) + chat_name = get_chat_name(msg.chat_id) + time_str = format_timestamp(msg.time) + # 截取前100个字符作为预览 + preview = cleaned_text[:100] + "..." if len(cleaned_text) > 100 else cleaned_text + message_lengths.append((chat_name, length, time_str, preview)) + + # 按长度排序,取前N个 + message_lengths.sort(key=lambda x: x[1], reverse=True) + return message_lengths[:top_n] + + +def analyze_text_lengths(chat_id: Optional[str] = None, start_time: Optional[float] = None, end_time: Optional[float] = None) -> None: + """Analyze processed_plain_text lengths with optional filters""" + + # 构建查询条件,排除特殊类型的消息 + query = Messages.select().where( + (Messages.is_emoji != 1) & + (Messages.is_picid != 1) & + (Messages.is_command != 1) + ) + + if chat_id: + query = query.where(Messages.chat_id == chat_id) + + if start_time: + query = query.where(Messages.time >= start_time) + + if end_time: + query = query.where(Messages.time <= end_time) + + messages = list(query) + + if not messages: + print("没有找到符合条件的消息") + return + + # 计算统计信息 + distribution = calculate_text_length_distribution(messages) + stats = get_text_length_stats(messages) + top_longest = get_top_longest_messages(messages, 10) + + # 显示结果 + print("\n=== Processed Plain Text 长度分析结果 ===") + print("(已排除表情、图片ID、命令类型消息,已排除[表情包]和[图片]标记消息,已清理回复引用)") + if chat_id: + print(f"聊天: {get_chat_name(chat_id)}") + else: + print("聊天: 全部聊天") + + if start_time and end_time: + print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}") + elif start_time: + print(f"时间范围: {format_timestamp(start_time)} 之后") + elif end_time: + print(f"时间范围: {format_timestamp(end_time)} 之前") + else: + print("时间范围: 不限制") + + print("\n基本统计:") + print(f"总消息数量: {len(messages)}") + print(f"有文本消息数量: {stats['count']}") + print(f"空文本消息数量: {stats['null_count']}") + print(f"被排除的消息数量: {stats['excluded_count']}") + if stats['count'] > 0: + print(f"最短长度: {stats['min']} 字符") + print(f"最长长度: {stats['max']} 字符") + print(f"平均长度: {stats['avg']:.2f} 字符") + print(f"中位数长度: {stats['median']:.2f} 字符") + + print("\n文本长度分布:") + total = stats['count'] + if total > 0: + for range_name, count in distribution.items(): + if count > 0: + percentage = count / total * 100 + print(f"{range_name} 字符: {count} ({percentage:.2f}%)") + + # 显示最长的消息 + if top_longest: + print(f"\n最长的 {len(top_longest)} 条消息:") + for i, (chat_name, length, time_str, preview) in enumerate(top_longest, 1): + print(f"{i}. [{chat_name}] {time_str}") + print(f" 长度: {length} 字符") + print(f" 预览: {preview}") + print() + + +def interactive_menu() -> None: + """Interactive menu for text length analysis""" + + while True: + print("\n" + "="*50) + print("Processed Plain Text 长度分析工具") + print("="*50) + print("1. 分析全部聊天") + print("2. 选择特定聊天分析") + print("q. 退出") + + choice = input("\n请选择分析模式 (1-2, q): ").strip() + + if choice.lower() == 'q': + print("再见!") + break + + chat_id = None + + if choice == "2": + # 显示可用的聊天列表 + chats = get_available_chats() + if not chats: + print("没有找到聊天数据") + continue + + print(f"\n可用的聊天 (共{len(chats)}个):") + for i, (_cid, name, count) in enumerate(chats, 1): + print(f"{i}. {name} ({count}条消息)") + + try: + chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) + if 1 <= chat_choice <= len(chats): + chat_id = chats[chat_choice - 1][0] + else: + print("无效选择") + continue + except ValueError: + print("请输入有效数字") + continue + + elif choice != "1": + print("无效选择") + continue + + # 获取时间范围 + start_time, end_time = get_time_range_input() + + # 执行分析 + analyze_text_lengths(chat_id, start_time, end_time) + + input("\n按回车键继续...") + + +if __name__ == "__main__": + interactive_menu() \ No newline at end of file diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d0f6e7744..c4391e7b3 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -117,30 +117,36 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - """获取字符串的嵌入向量,处理异步调用""" + """获取字符串的嵌入向量,使用完全同步的方式避免事件循环问题""" + # 创建新的事件循环并在完成后立即关闭 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: - # 尝试获取当前事件循环 - asyncio.get_running_loop() - # 如果在事件循环中,使用线程池执行 - import concurrent.futures + # 创建新的LLMRequest实例 + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config - def run_in_thread(): - return asyncio.run(get_embedding(s)) + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) - result = future.result() - if result is None: - logger.error(f"获取嵌入失败: {s}") - return [] - return result - except RuntimeError: - # 没有运行的事件循环,直接运行 - result = asyncio.run(get_embedding(s)) - if result is None: + # 使用新的事件循环运行异步方法 + embedding, _ = loop.run_until_complete(llm.get_embedding(s)) + + if embedding and len(embedding) > 0: + return embedding + else: logger.error(f"获取嵌入失败: {s}") return [] - return result + + except Exception as e: + logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + return [] + finally: + # 确保事件循环被正确关闭 + try: + loop.close() + except Exception: + pass def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: """使用多线程批量获取嵌入向量 @@ -181,8 +187,14 @@ class EmbeddingStore: for i, s in enumerate(chunk_strs): try: - # 直接使用异步函数 - embedding = asyncio.run(llm.get_embedding(s)) + # 在线程中创建独立的事件循环 + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + embedding = loop.run_until_complete(llm.get_embedding(s)) + finally: + loop.close() + if embedding and len(embedding) > 0: chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 else: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 0b9ec7798..8ea9a4d50 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -111,6 +111,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: """获取文本的embedding向量""" + # 每次都创建新的LLMRequest实例以避免事件循环冲突 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: embedding, _ = await llm.get_embedding(text) diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 97c345466..eb74b0dfe 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -159,14 +159,23 @@ class ClientRegistry: return decorator - def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient: + def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient: """ 获取注册的API客户端实例 Args: api_provider: APIProvider实例 + force_new: 是否强制创建新实例(用于解决事件循环问题) Returns: BaseClient: 注册的API客户端实例 """ + # 如果强制创建新实例,直接创建不使用缓存 + if force_new: + if client_class := self.client_registry.get(api_provider.client_type): + return client_class(api_provider) + else: + raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") + + # 正常的缓存逻辑 if api_provider.name not in self.client_instance_cache: if client_class := self.client_registry.get(api_provider.client_type): self.client_instance_cache[api_provider.name] = client_class(api_provider) diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py index 0b4f1e709..bf09f0753 100644 --- a/src/llm_models/model_client/openai_client.py +++ b/src/llm_models/model_client/openai_client.py @@ -380,6 +380,7 @@ class OpenaiClient(BaseClient): base_url=api_provider.base_url, api_key=api_provider.api_key, max_retries=0, + timeout=api_provider.timeout, ) async def get_response( @@ -512,6 +513,11 @@ class OpenaiClient(BaseClient): extra_body=extra_params, ) except APIConnectionError as e: + # 添加详细的错误信息以便调试 + logger.error(f"OpenAI API连接错误(嵌入模型): {str(e)}") + logger.error(f"错误类型: {type(e)}") + if hasattr(e, '__cause__') and e.__cause__: + logger.error(f"底层错误: {str(e.__cause__)}") raise NetworkConnectionError() from e except APIStatusError as e: # 重封装APIError为RespNotOkException diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1c7a56eff..c919f789a 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -416,7 +416,10 @@ class LLMRequest: ) model_info = model_config.get_model_info(least_used_model_name) api_provider = model_config.get_provider(model_info.api_provider) - client = client_registry.get_client_class_instance(api_provider) + + # 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题 + force_new_client = (self.request_type == "embedding") + client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client) logger.debug(f"选择请求模型: {model_info.name}") total_tokens, penalty, usage_penalty = self.model_usage[model_info.name] self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用