From 3e854719ee2b93e0195fbf4031c0ba76bb178df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=98=A5=E6=B2=B3=E6=99=B4?= Date: Tue, 10 Jun 2025 17:31:05 +0900 Subject: [PATCH] ruff --- requirements.txt | Bin 890 -> 908 bytes scripts/analyze_expression_similarity.py | 113 ++- scripts/analyze_expressions.py | 72 +- scripts/analyze_group_similarity.py | 8 +- scripts/cleanup_expressions.py | 62 +- scripts/find_similar_expression.py | 119 +-- scripts/message_retrieval_script.py | 922 +++++++++--------- src/chat/heart_flow/utils_chat.py | 4 +- src/chat/normal_chat/normal_chat_generator.py | 4 +- .../normal_chat/willing/willing_manager.py | 2 +- src/chat/utils/chat_message_builder.py | 12 +- src/person_info/impression_update_task.py | 4 +- src/person_info/relationship_manager.py | 10 +- 13 files changed, 686 insertions(+), 646 deletions(-) diff --git a/requirements.txt b/requirements.txt index 099dbfc684cd54c9eddff56ec8173546f86f2f6f..4ac814b21434cd97c3d68c793e2236d6a85c6116 100644 GIT binary patch delta 26 hcmeyx*2BJ`idl$@p_n0+A(0`8A)ld$A&-HJ0RUh_1 str: """清理群组名称,只保留中文和英文字符""" - cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name) + cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) if not cleaned: cleaned = datetime.now().strftime("%Y%m%d") return cleaned + def get_group_name(stream_id: str) -> str: """从数据库中获取群组名称""" conn = sqlite3.connect("data/maibot.db") @@ -43,6 +45,7 @@ def get_group_name(stream_id: str) -> str: return clean_group_name(f"{platform}{stream_id[:8]}") return stream_id + def format_timestamp(timestamp: float) -> str: """将时间戳转换为可读的时间格式""" if not timestamp: @@ -50,132 +53,140 @@ def format_timestamp(timestamp: float) -> str: try: dt = datetime.fromtimestamp(timestamp) return dt.strftime("%Y-%m-%d %H:%M:%S") - except: + except Exception as e: + print(f"时间戳格式化错误: {e}") return "未知" + def load_expressions(chat_id: str) -> List[Dict]: """加载指定群聊的表达方式""" style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") - + style_exprs = [] - + if os.path.exists(style_file): with open(style_file, "r", encoding="utf-8") as f: style_exprs = json.load(f) - + return style_exprs + def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]: """找出每个表达方式最相似的top_k个表达方式""" if not expressions: return {} - + # 分别准备情景和表达方式的文本数据 - situations = [expr['situation'] for expr in expressions] - styles = [expr['style'] for expr in expressions] - + situations = [expr["situation"] for expr in expressions] + styles = [expr["style"] for expr in expressions] + # 使用TF-IDF向量化 vectorizer = TfidfVectorizer() situation_matrix = vectorizer.fit_transform(situations) style_matrix = vectorizer.fit_transform(styles) - + # 计算余弦相似度 situation_similarity = cosine_similarity(situation_matrix) style_similarity = cosine_similarity(style_matrix) - + # 对每个表达方式找出最相似的top_k个 similar_expressions = {} - for i, expr in enumerate(expressions): + for i, _ in enumerate(expressions): # 获取相似度分数 situation_scores = situation_similarity[i] style_scores = style_similarity[i] - + # 获取top_k的索引(排除自己) - situation_indices = np.argsort(situation_scores)[::-1][1:top_k+1] - style_indices = np.argsort(style_scores)[::-1][1:top_k+1] - + situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1] + style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1] + similar_situations = [] similar_styles = [] - + # 处理相似情景 for idx in situation_indices: if situation_scores[idx] > 0: # 只保留有相似度的 - similar_situations.append(( - expressions[idx]['situation'], - expressions[idx]['style'], # 添加对应的原始表达 - situation_scores[idx] - )) - + similar_situations.append( + ( + expressions[idx]["situation"], + expressions[idx]["style"], # 添加对应的原始表达 + situation_scores[idx], + ) + ) + # 处理相似表达 for idx in style_indices: if style_scores[idx] > 0: # 只保留有相似度的 - similar_styles.append(( - expressions[idx]['style'], - expressions[idx]['situation'], # 添加对应的原始情景 - style_scores[idx] - )) - + similar_styles.append( + ( + expressions[idx]["style"], + expressions[idx]["situation"], # 添加对应的原始情景 + style_scores[idx], + ) + ) + if similar_situations or similar_styles: - similar_expressions[i] = { - 'situations': similar_situations, - 'styles': similar_styles - } - + similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles} + return similar_expressions + def main(): # 获取所有群聊ID style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) chat_ids = [os.path.basename(d) for d in style_dirs] - + if not chat_ids: print("没有找到任何群聊的表达方式数据") return - + print("可用的群聊:") for i, chat_id in enumerate(chat_ids, 1): group_name = get_group_name(chat_id) print(f"{i}. {group_name}") - + while True: try: choice = int(input("\n请选择要分析的群聊编号 (输入0退出): ")) if choice == 0: break if 1 <= choice <= len(chat_ids): - chat_id = chat_ids[choice-1] + chat_id = chat_ids[choice - 1] break print("无效的选择,请重试") except ValueError: print("请输入有效的数字") - + if choice == 0: return - + # 加载表达方式 style_exprs = load_expressions(chat_id) - + group_name = get_group_name(chat_id) print(f"\n分析群聊 {group_name} 的表达方式:") - + similar_styles = find_similar_expressions(style_exprs) for i, expr in enumerate(style_exprs): if i in similar_styles: print("\n" + "-" * 20) print(f"表达方式:{expr['style']} <---> 情景:{expr['situation']}") - - if similar_styles[i]['styles']: + + if similar_styles[i]["styles"]: print("\n\033[33m相似表达:\033[0m") - for similar_style, original_situation, score in similar_styles[i]['styles']: + for similar_style, original_situation, score in similar_styles[i]["styles"]: print(f"\033[33m{similar_style},score:{score:.3f},对应情景:{original_situation}\033[0m") - - if similar_styles[i]['situations']: + + if similar_styles[i]["situations"]: print("\n\033[32m相似情景:\033[0m") - for similar_situation, original_style, score in similar_styles[i]['situations']: + for similar_situation, original_style, score in similar_styles[i]["situations"]: print(f"\033[32m{similar_situation},score:{score:.3f},对应表达:{original_style}\033[0m") - - print(f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}") + + print( + f"\n激活值:{expr.get('count', 1):.3f},上次激活时间:{format_timestamp(expr.get('last_active_time'))}" + ) print("-" * 20) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/analyze_expressions.py b/scripts/analyze_expressions.py index 87d91fa3b..0cda31a06 100644 --- a/scripts/analyze_expressions.py +++ b/scripts/analyze_expressions.py @@ -6,15 +6,17 @@ from datetime import datetime from typing import Dict, List, Any import sqlite3 + def clean_group_name(name: str) -> str: """清理群组名称,只保留中文和英文字符""" # 提取中文和英文字符 - cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name) + cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) # 如果清理后为空,使用当前日期 if not cleaned: cleaned = datetime.now().strftime("%Y%m%d") return cleaned + def get_group_name(stream_id: str) -> str: """从数据库中获取群组名称""" conn = sqlite3.connect("data/maibot.db") @@ -42,41 +44,44 @@ def get_group_name(stream_id: str) -> str: return clean_group_name(f"{platform}{stream_id[:8]}") return stream_id + def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]: """加载指定群组的表达方式""" learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") learnt_grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json") personality_file = os.path.join("data", "expression", "personality", "expressions.json") - + style_expressions = [] grammar_expressions = [] personality_expressions = [] - + if os.path.exists(learnt_style_file): with open(learnt_style_file, "r", encoding="utf-8") as f: style_expressions = json.load(f) - + if os.path.exists(learnt_grammar_file): with open(learnt_grammar_file, "r", encoding="utf-8") as f: grammar_expressions = json.load(f) - + if os.path.exists(personality_file): with open(personality_file, "r", encoding="utf-8") as f: personality_expressions = json.load(f) - + return style_expressions, grammar_expressions, personality_expressions + def format_time(timestamp: float) -> str: """格式化时间戳为可读字符串""" return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + def write_expressions(f, expressions: List[Dict[str, Any]], title: str): """写入表达方式列表""" if not expressions: f.write(f"{title}:暂无数据\n") f.write("-" * 40 + "\n") return - + f.write(f"{title}:\n") for expr in expressions: count = expr.get("count", 0) @@ -87,103 +92,111 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str): f.write(f"最后活跃: {format_time(last_active)}\n") f.write("-" * 40 + "\n") -def write_group_report(group_file: str, group_name: str, chat_id: str, style_exprs: List[Dict[str, Any]], grammar_exprs: List[Dict[str, Any]]): + +def write_group_report( + group_file: str, + group_name: str, + chat_id: str, + style_exprs: List[Dict[str, Any]], + grammar_exprs: List[Dict[str, Any]], +): """写入群组详细报告""" with open(group_file, "w", encoding="utf-8") as gf: gf.write(f"群组: {group_name} (ID: {chat_id})\n") gf.write("=" * 80 + "\n\n") - + # 写入语言风格 gf.write("【语言风格】\n") gf.write("=" * 40 + "\n") write_expressions(gf, style_exprs, "语言风格") gf.write("\n") - + # 写入句法特点 gf.write("【句法特点】\n") gf.write("=" * 40 + "\n") write_expressions(gf, grammar_exprs, "句法特点") + def analyze_expressions(): """分析所有群组的表达方式""" # 获取所有群组ID style_dir = os.path.join("data", "expression", "learnt_style") chat_ids = [d for d in os.listdir(style_dir) if os.path.isdir(os.path.join(style_dir, d))] - + # 创建输出目录 output_dir = "data/expression_analysis" personality_dir = os.path.join(output_dir, "personality") os.makedirs(output_dir, exist_ok=True) os.makedirs(personality_dir, exist_ok=True) - + # 生成时间戳 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - + # 创建总报告 summary_file = os.path.join(output_dir, f"summary_{timestamp}.txt") with open(summary_file, "w", encoding="utf-8") as f: f.write(f"表达方式分析报告 - 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") f.write("=" * 80 + "\n\n") - + # 先处理人格表达 personality_exprs = [] personality_file = os.path.join("data", "expression", "personality", "expressions.json") if os.path.exists(personality_file): with open(personality_file, "r", encoding="utf-8") as pf: personality_exprs = json.load(pf) - + # 保存人格表达总数 total_personality = len(personality_exprs) - + # 排序并取前20条 personality_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) personality_exprs = personality_exprs[:20] - + # 写入人格表达报告 personality_report = os.path.join(personality_dir, f"expressions_{timestamp}.txt") with open(personality_report, "w", encoding="utf-8") as pf: pf.write("【人格表达方式】\n") pf.write("=" * 40 + "\n") write_expressions(pf, personality_exprs, "人格表达") - + # 写入总报告摘要中的人格表达部分 f.write("【人格表达方式】\n") f.write("=" * 40 + "\n") f.write(f"人格表达总数: {total_personality} (显示前20条)\n") f.write(f"详细报告: {personality_report}\n") f.write("-" * 40 + "\n\n") - + # 处理各个群组的表达方式 f.write("【群组表达方式】\n") f.write("=" * 40 + "\n\n") - + for chat_id in chat_ids: style_exprs, grammar_exprs, _ = load_expressions(chat_id) - + # 保存总数 total_style = len(style_exprs) total_grammar = len(grammar_exprs) - + # 分别排序 style_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) grammar_exprs.sort(key=lambda x: x.get("count", 0), reverse=True) - + # 只取前20条 style_exprs = style_exprs[:20] grammar_exprs = grammar_exprs[:20] - + # 获取群组名称 group_name = get_group_name(chat_id) - + # 创建群组子目录(使用清理后的名称) safe_group_name = clean_group_name(group_name) group_dir = os.path.join(output_dir, f"{safe_group_name}_{chat_id}") os.makedirs(group_dir, exist_ok=True) - + # 写入群组详细报告 group_file = os.path.join(group_dir, f"expressions_{timestamp}.txt") write_group_report(group_file, group_name, chat_id, style_exprs, grammar_exprs) - + # 写入总报告摘要 f.write(f"群组: {group_name} (ID: {chat_id})\n") f.write("-" * 40 + "\n") @@ -191,11 +204,12 @@ def analyze_expressions(): f.write(f"句法特点总数: {total_grammar} (显示前20条)\n") f.write(f"详细报告: {group_file}\n") f.write("-" * 40 + "\n\n") - + print("分析报告已生成:") print(f"总报告: {summary_file}") print(f"人格表达报告: {personality_report}") print(f"各群组详细报告位于: {output_dir}") + if __name__ == "__main__": - analyze_expressions() + analyze_expressions() diff --git a/scripts/analyze_group_similarity.py b/scripts/analyze_group_similarity.py index 5775a7121..f1d53ee20 100644 --- a/scripts/analyze_group_similarity.py +++ b/scripts/analyze_group_similarity.py @@ -71,14 +71,14 @@ def analyze_group_similarity(): # 获取所有群组目录 base_dir = Path("data/expression/learnt_style") group_dirs = [d for d in base_dir.iterdir() if d.is_dir()] - + # 加载所有群组的数据并过滤 valid_groups = [] valid_names = [] valid_situations = [] valid_styles = [] valid_combined = [] - + for d in group_dirs: situations, styles, combined, total_count = load_group_data(d) if total_count >= 50: # 只保留数据量大于等于50的群组 @@ -87,11 +87,11 @@ def analyze_group_similarity(): valid_situations.append(" ".join(situations)) valid_styles.append(" ".join(styles)) valid_combined.append(" ".join(combined)) - + if not valid_groups: print("没有找到数据量大于等于50的群组") return - + # 创建TF-IDF向量化器 vectorizer = TfidfVectorizer() diff --git a/scripts/cleanup_expressions.py b/scripts/cleanup_expressions.py index c5e66133a..3d7ba1b55 100644 --- a/scripts/cleanup_expressions.py +++ b/scripts/cleanup_expressions.py @@ -3,117 +3,123 @@ import json import random from typing import List, Dict, Tuple import glob -from datetime import datetime MAX_EXPRESSION_COUNT = 300 # 每个群最多保留的表达方式数量 MIN_COUNT_THRESHOLD = 0.01 # 最小使用次数阈值 + def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]: """加载指定群聊的表达方式""" style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json") - + style_exprs = [] grammar_exprs = [] - + if os.path.exists(style_file): with open(style_file, "r", encoding="utf-8") as f: style_exprs = json.load(f) - + if os.path.exists(grammar_file): with open(grammar_file, "r", encoding="utf-8") as f: grammar_exprs = json.load(f) - + return style_exprs, grammar_exprs + def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None: """保存表达方式到文件""" style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json") - + os.makedirs(os.path.dirname(style_file), exist_ok=True) os.makedirs(os.path.dirname(grammar_file), exist_ok=True) - + with open(style_file, "w", encoding="utf-8") as f: json.dump(style_exprs, f, ensure_ascii=False, indent=2) - + with open(grammar_file, "w", encoding="utf-8") as f: json.dump(grammar_exprs, f, ensure_ascii=False, indent=2) + def cleanup_expressions(expressions: List[Dict]) -> List[Dict]: """清理表达方式列表""" if not expressions: return [] - + # 1. 移除使用次数过低的表达方式 expressions = [expr for expr in expressions if expr.get("count", 0) > MIN_COUNT_THRESHOLD] - + # 2. 如果数量超过限制,随机删除多余的 if len(expressions) > MAX_EXPRESSION_COUNT: # 按使用次数排序 expressions.sort(key=lambda x: x.get("count", 0), reverse=True) - + # 保留前50%的高频表达方式 keep_count = MAX_EXPRESSION_COUNT // 2 keep_exprs = expressions[:keep_count] - + # 从剩余的表达方式中随机选择 remaining_exprs = expressions[keep_count:] random.shuffle(remaining_exprs) - keep_exprs.extend(remaining_exprs[:MAX_EXPRESSION_COUNT - keep_count]) - + keep_exprs.extend(remaining_exprs[: MAX_EXPRESSION_COUNT - keep_count]) + expressions = keep_exprs - + return expressions + def main(): # 获取所有群聊ID style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) chat_ids = [os.path.basename(d) for d in style_dirs] - + if not chat_ids: print("没有找到任何群聊的表达方式数据") return - + print(f"开始清理 {len(chat_ids)} 个群聊的表达方式数据...") - + total_style_before = 0 total_style_after = 0 total_grammar_before = 0 total_grammar_after = 0 - + for chat_id in chat_ids: print(f"\n处理群聊 {chat_id}:") - + # 加载表达方式 style_exprs, grammar_exprs = load_expressions(chat_id) - + # 记录清理前的数量 style_count_before = len(style_exprs) grammar_count_before = len(grammar_exprs) total_style_before += style_count_before total_grammar_before += grammar_count_before - + # 清理表达方式 style_exprs = cleanup_expressions(style_exprs) grammar_exprs = cleanup_expressions(grammar_exprs) - + # 记录清理后的数量 style_count_after = len(style_exprs) grammar_count_after = len(grammar_exprs) total_style_after += style_count_after total_grammar_after += grammar_count_after - + # 保存清理后的表达方式 save_expressions(chat_id, style_exprs, grammar_exprs) - + print(f"语言风格: {style_count_before} -> {style_count_after}") print(f"句法特点: {grammar_count_before} -> {grammar_count_after}") - + print("\n清理完成!") print(f"语言风格总数: {total_style_before} -> {total_style_after}") print(f"句法特点总数: {total_grammar_before} -> {total_grammar_after}") - print(f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式") + print( + f"总共清理了 {total_style_before + total_grammar_before - total_style_after - total_grammar_after} 条表达方式" + ) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scripts/find_similar_expression.py b/scripts/find_similar_expression.py index 21d34e1a8..23f9e63d9 100644 --- a/scripts/find_similar_expression.py +++ b/scripts/find_similar_expression.py @@ -1,5 +1,6 @@ import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import json @@ -15,13 +16,15 @@ import random from src.llm_models.utils_model import LLMRequest from src.config.config import global_config + def clean_group_name(name: str) -> str: """清理群组名称,只保留中文和英文字符""" - cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name) + cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name) if not cleaned: cleaned = datetime.now().strftime("%Y%m%d") return cleaned + def get_group_name(stream_id: str) -> str: """从数据库中获取群组名称""" conn = sqlite3.connect("data/maibot.db") @@ -49,76 +52,79 @@ def get_group_name(stream_id: str) -> str: return clean_group_name(f"{platform}{stream_id[:8]}") return stream_id + def load_expressions(chat_id: str) -> List[Dict]: """加载指定群聊的表达方式""" style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json") - + style_exprs = [] - + if os.path.exists(style_file): with open(style_file, "r", encoding="utf-8") as f: style_exprs = json.load(f) - + # 如果表达方式超过10个,随机选择10个 if len(style_exprs) > 50: style_exprs = random.sample(style_exprs, 50) print(f"\n从 {len(style_exprs)} 个表达方式中随机选择了 10 个进行匹配") - + return style_exprs -def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]: + +def find_similar_expressions_tfidf( + input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10 +) -> List[Tuple[str, str, float]]: """使用TF-IDF方法找出与输入文本最相似的top_k个表达方式""" if not expressions: return [] - + # 准备文本数据 if mode == "style": - texts = [expr['style'] for expr in expressions] + texts = [expr["style"] for expr in expressions] elif mode == "situation": - texts = [expr['situation'] for expr in expressions] + texts = [expr["situation"] for expr in expressions] else: # both texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] - + texts.append(input_text) # 添加输入文本 - + # 使用TF-IDF向量化 vectorizer = TfidfVectorizer() tfidf_matrix = vectorizer.fit_transform(texts) - + # 计算余弦相似度 similarity_matrix = cosine_similarity(tfidf_matrix) - + # 获取输入文本的相似度分数(最后一行) scores = similarity_matrix[-1][:-1] # 排除与自身的相似度 - + # 获取top_k的索引 top_indices = np.argsort(scores)[::-1][:top_k] - + # 获取相似表达 similar_exprs = [] for idx in top_indices: if scores[idx] > 0: # 只保留有相似度的 - similar_exprs.append(( - expressions[idx]['style'], - expressions[idx]['situation'], - scores[idx] - )) - + similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx])) + return similar_exprs -async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]: + +async def find_similar_expressions_embedding( + input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5 +) -> List[Tuple[str, str, float]]: """使用嵌入模型找出与输入文本最相似的top_k个表达方式""" if not expressions: return [] - + # 准备文本数据 if mode == "style": - texts = [expr['style'] for expr in expressions] + texts = [expr["style"] for expr in expressions] elif mode == "situation": - texts = [expr['situation'] for expr in expressions] + texts = [expr["situation"] for expr in expressions] else: # both texts = [f"{expr['situation']} {expr['style']}" for expr in expressions] - + # 获取嵌入向量 llm_request = LLMRequest(global_config.model.embedding) text_embeddings = [] @@ -126,73 +132,70 @@ async def find_similar_expressions_embedding(input_text: str, expressions: List[ embedding = await llm_request.get_embedding(text) if embedding: text_embeddings.append(embedding) - + input_embedding = await llm_request.get_embedding(input_text) if not input_embedding or not text_embeddings: return [] - + # 计算余弦相似度 text_embeddings = np.array(text_embeddings) similarities = np.dot(text_embeddings, input_embedding) / ( np.linalg.norm(text_embeddings, axis=1) * np.linalg.norm(input_embedding) ) - + # 获取top_k的索引 top_indices = np.argsort(similarities)[::-1][:top_k] - + # 获取相似表达 similar_exprs = [] for idx in top_indices: if similarities[idx] > 0: # 只保留有相似度的 - similar_exprs.append(( - expressions[idx]['style'], - expressions[idx]['situation'], - similarities[idx] - )) - + similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx])) + return similar_exprs + async def main(): # 获取所有群聊ID style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*")) chat_ids = [os.path.basename(d) for d in style_dirs] - + if not chat_ids: print("没有找到任何群聊的表达方式数据") return - + print("可用的群聊:") for i, chat_id in enumerate(chat_ids, 1): group_name = get_group_name(chat_id) print(f"{i}. {group_name}") - + while True: try: choice = int(input("\n请选择要分析的群聊编号 (输入0退出): ")) if choice == 0: break if 1 <= choice <= len(chat_ids): - chat_id = chat_ids[choice-1] + chat_id = chat_ids[choice - 1] break print("无效的选择,请重试") except ValueError: print("请输入有效的数字") - + if choice == 0: return - + # 加载表达方式 style_exprs = load_expressions(chat_id) - + group_name = get_group_name(chat_id) print(f"\n已选择群聊:{group_name}") - + # 选择匹配模式 print("\n请选择匹配模式:") print("1. 匹配表达方式") print("2. 匹配情景") print("3. 两者都考虑") - + while True: try: mode_choice = int(input("\n请选择匹配模式 (1-3): ")) @@ -201,19 +204,15 @@ async def main(): print("无效的选择,请重试") except ValueError: print("请输入有效的数字") - - mode_map = { - 1: "style", - 2: "situation", - 3: "both" - } + + mode_map = {1: "style", 2: "situation", 3: "both"} mode = mode_map[mode_choice] - + # 选择匹配方法 print("\n请选择匹配方法:") print("1. TF-IDF方法") print("2. 嵌入模型方法") - + while True: try: method_choice = int(input("\n请选择匹配方法 (1-2): ")) @@ -222,20 +221,20 @@ async def main(): print("无效的选择,请重试") except ValueError: print("请输入有效的数字") - + while True: input_text = input("\n请输入要匹配的文本(输入q退出): ") - if input_text.lower() == 'q': + if input_text.lower() == "q": break - + if not input_text.strip(): continue - + if method_choice == 1: similar_exprs = find_similar_expressions_tfidf(input_text, style_exprs, mode) else: similar_exprs = await find_similar_expressions_embedding(input_text, style_exprs, mode) - + if similar_exprs: print("\n找到以下相似表达:") for style, situation, score in similar_exprs: @@ -246,6 +245,8 @@ async def main(): else: print("\n没有找到相似的表达方式") + if __name__ == "__main__": import asyncio - asyncio.run(main()) \ No newline at end of file + + asyncio.run(main()) diff --git a/scripts/message_retrieval_script.py b/scripts/message_retrieval_script.py index 25cacea7e..1601e637c 100644 --- a/scripts/message_retrieval_script.py +++ b/scripts/message_retrieval_script.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# ruff: noqa: E402 """ 消息检索脚本 @@ -10,319 +11,415 @@ 5. 应用LLM分析,将结果存储到数据库person_info中 """ -import sys -import os import asyncio import json -import re import random -import time -import math -from datetime import datetime, timedelta +import sys from collections import defaultdict -from typing import Dict, List, Any, Optional +from datetime import datetime, timedelta +from difflib import SequenceMatcher from pathlib import Path +from typing import Dict, List, Any, Optional + +import jieba +from json_repair import repair_json +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity # 添加项目根目录到Python路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) -from src.common.database.database_model import Messages -from src.person_info.person_info import PersonInfoManager -from src.config.config import global_config -from src.common.database.database import db from src.chat.utils.chat_message_builder import build_readable_messages -from src.person_info.person_info import person_info_manager -from src.llm_models.utils_model import LLMRequest -from src.individuality.individuality import individuality -from json_repair import repair_json -from difflib import SequenceMatcher -import jieba -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.metrics.pairwise import cosine_similarity +from src.common.database.database_model import Messages from src.common.logger_manager import get_logger +from src.common.database.database import db +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import PersonInfoManager, person_info_manager logger = get_logger("message_retrieval") + +def get_time_range(time_period: str) -> Optional[float]: + """根据时间段选择获取起始时间戳""" + now = datetime.now() + + if time_period == "all": + return None + elif time_period == "3months": + start_time = now - timedelta(days=90) + elif time_period == "1month": + start_time = now - timedelta(days=30) + elif time_period == "1week": + start_time = now - timedelta(days=7) + else: + raise ValueError(f"不支持的时间段: {time_period}") + + return start_time.timestamp() + + +def get_person_id(platform: str, user_id: str) -> str: + """根据platform和user_id计算person_id""" + return PersonInfoManager.get_person_id(platform, user_id) + + +def split_messages_by_count(messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]: + """将消息按指定数量分段""" + chunks = [] + for i in range(0, len(messages), count): + chunks.append(messages[i : i + count]) + return chunks + + +async def build_name_mapping(messages: List[Dict[str, Any]], target_person_name: str) -> Dict[str, str]: + """构建用户名称映射,和relationship_manager中的逻辑一致""" + name_mapping = {} + current_user = "A" + user_count = 1 + + # 遍历消息,构建映射 + for msg in messages: + await person_info_manager.get_or_create_person( + platform=msg.get("chat_info_platform"), + user_id=msg.get("user_id"), + nickname=msg.get("user_nickname"), + user_cardname=msg.get("user_cardname"), + ) + replace_user_id = msg.get("user_id") + replace_platform = msg.get("chat_info_platform") + replace_person_id = get_person_id(replace_platform, replace_user_id) + replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") + + # 跳过机器人自己 + if replace_user_id == global_config.bot.qq_account: + name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}" + continue + + # 跳过目标用户 + if replace_person_name == target_person_name: + name_mapping[replace_person_name] = f"{target_person_name}" + continue + + # 其他用户映射 + if replace_person_name not in name_mapping: + if current_user > "Z": + current_user = "A" + user_count += 1 + name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}" + current_user = chr(ord(current_user) + 1) + + return name_mapping + + +def build_focus_readable_messages(messages: List[Dict[str, Any]], target_person_id: str = None) -> str: + """格式化消息,只保留目标用户和bot消息附近的内容,和relationship_manager中的逻辑一致""" + # 找到目标用户和bot的消息索引 + target_indices = [] + for i, msg in enumerate(messages): + user_id = msg.get("user_id") + platform = msg.get("chat_info_platform") + person_id = get_person_id(platform, user_id) + if person_id == target_person_id: + target_indices.append(i) + + if not target_indices: + return "" + + # 获取需要保留的消息索引 + keep_indices = set() + for idx in target_indices: + # 获取前后5条消息的索引 + start_idx = max(0, idx - 5) + end_idx = min(len(messages), idx + 6) + keep_indices.update(range(start_idx, end_idx)) + + # 将索引排序 + keep_indices = sorted(list(keep_indices)) + + # 按顺序构建消息组 + message_groups = [] + current_group = [] + + for i in range(len(messages)): + if i in keep_indices: + current_group.append(messages[i]) + elif current_group: + # 如果当前组不为空,且遇到不保留的消息,则结束当前组 + if current_group: + message_groups.append(current_group) + current_group = [] + + # 添加最后一组 + if current_group: + message_groups.append(current_group) + + # 构建最终的消息文本 + result = [] + for i, group in enumerate(message_groups): + if i > 0: + result.append("...") + group_text = build_readable_messages( + messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False + ) + result.append(group_text) + + return "\n".join(result) + + +def tfidf_similarity(s1, s2): + """使用 TF-IDF 和余弦相似度计算两个句子的相似性""" + # 确保输入是字符串类型 + if isinstance(s1, list): + s1 = " ".join(str(x) for x in s1) + if isinstance(s2, list): + s2 = " ".join(str(x) for x in s2) + + # 转换为字符串类型 + s1 = str(s1) + s2 = str(s2) + + # 1. 使用 jieba 进行分词 + s1_words = " ".join(jieba.cut(s1)) + s2_words = " ".join(jieba.cut(s2)) + + # 2. 将两句话放入一个列表中 + corpus = [s1_words, s2_words] + + # 3. 创建 TF-IDF 向量化器并进行计算 + try: + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform(corpus) + except ValueError: + # 如果句子完全由停用词组成,或者为空,可能会报错 + return 0.0 + + # 4. 计算余弦相似度 + similarity_matrix = cosine_similarity(tfidf_matrix) + + # 返回 s1 和 s2 的相似度 + return similarity_matrix[0, 1] + + +def sequence_similarity(s1, s2): + """使用 SequenceMatcher 计算两个句子的相似性""" + return SequenceMatcher(None, s1, s2).ratio() + + +def calculate_time_weight(point_time: str, current_time: str) -> float: + """计算基于时间的权重系数""" + try: + point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S") + current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S") + time_diff = current_timestamp - point_timestamp + hours_diff = time_diff.total_seconds() / 3600 + + if hours_diff <= 1: # 1小时内 + return 1.0 + elif hours_diff <= 24: # 1-24小时 + # 从1.0快速递减到0.7 + return 1.0 - (hours_diff - 1) * (0.3 / 23) + elif hours_diff <= 24 * 7: # 24小时-7天 + # 从0.7缓慢回升到0.95 + return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6)) + else: # 7-30天 + # 从0.95缓慢递减到0.1 + days_diff = hours_diff / 24 - 7 + return max(0.1, 0.95 - days_diff * (0.85 / 23)) + except Exception as e: + logger.error(f"计算时间权重失败: {e}") + return 0.5 # 发生错误时返回中等权重 + + +def filter_selected_chats( + grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int] +) -> Dict[str, List[Dict[str, Any]]]: + """根据用户选择过滤群聊""" + chat_items = list(grouped_messages.items()) + selected_chats = {} + + for idx in selected_indices: + chat_id, messages = chat_items[idx - 1] # 转换为0基索引 + selected_chats[chat_id] = messages + + return selected_chats + + +def get_user_selection(total_count: int) -> List[int]: + """获取用户选择的群聊编号""" + while True: + print(f"\n请选择要分析的群聊 (1-{total_count}):") + print("输入格式:") + print(" 单个: 1") + print(" 多个: 1,3,5") + print(" 范围: 1-3") + print(" 全部: all 或 a") + print(" 退出: quit 或 q") + + user_input = input("请输入选择: ").strip().lower() + + if user_input in ["quit", "q"]: + return [] + + if user_input in ["all", "a"]: + return list(range(1, total_count + 1)) + + try: + selected = [] + + # 处理逗号分隔的输入 + parts = user_input.split(",") + + for part in parts: + part = part.strip() + + if "-" in part: + # 处理范围输入 (如: 1-3) + start, end = part.split("-") + start_num = int(start.strip()) + end_num = int(end.strip()) + + if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num: + selected.extend(range(start_num, end_num + 1)) + else: + raise ValueError("范围超出有效范围") + else: + # 处理单个数字 + num = int(part) + if 1 <= num <= total_count: + selected.append(num) + else: + raise ValueError("数字超出有效范围") + + # 去重并排序 + selected = sorted(list(set(selected))) + + if selected: + return selected + else: + print("错误: 请输入有效的选择") + + except ValueError as e: + print(f"错误: 输入格式无效 - {e}") + print("请重新输入") + + +def display_chat_list(grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None: + """显示群聊列表""" + print("\n找到以下群聊:") + print("=" * 60) + + for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1): + first_msg = messages[0] + group_name = first_msg.get("chat_info_group_name", "私聊") + group_id = first_msg.get("chat_info_group_id", chat_id) + + # 计算时间范围 + start_time = datetime.fromtimestamp(messages[0]["time"]).strftime("%Y-%m-%d") + end_time = datetime.fromtimestamp(messages[-1]["time"]).strftime("%Y-%m-%d") + + print(f"{i:2d}. {group_name}") + print(f" 群ID: {group_id}") + print(f" 消息数: {len(messages)}") + print(f" 时间范围: {start_time} ~ {end_time}") + print("-" * 60) + + +def check_similarity(text1, text2, tfidf_threshold=0.5, seq_threshold=0.6): + """使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的""" + # 计算两种相似度 + tfidf_sim = tfidf_similarity(text1, text2) + seq_sim = sequence_similarity(text1, text2) + + # 只要其中一种方法达到阈值就认为是相似的 + return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold + + class MessageRetrievalScript: def __init__(self): """初始化脚本""" - self.person_info_manager = PersonInfoManager() self.bot_qq = str(global_config.bot.qq_account) - + # 初始化LLM请求器,和relationship_manager一样 self.relationship_llm = LLMRequest( model=global_config.model.relation, request_type="relationship", ) - - def get_person_id(self, platform: str, user_id: str) -> str: - """根据platform和user_id计算person_id""" - return PersonInfoManager.get_person_id(platform, user_id) - - def get_time_range(self, time_period: str) -> Optional[float]: - """根据时间段选择获取起始时间戳""" - now = datetime.now() - - if time_period == "all": - return None - elif time_period == "3months": - start_time = now - timedelta(days=90) - elif time_period == "1month": - start_time = now - timedelta(days=30) - elif time_period == "1week": - start_time = now - timedelta(days=7) - else: - raise ValueError(f"不支持的时间段: {time_period}") - - return start_time.timestamp() - + def retrieve_messages(self, user_qq: str, time_period: str) -> Dict[str, List[Dict[str, Any]]]: """检索消息""" print(f"开始检索用户 {user_qq} 的消息...") - + # 计算person_id - person_id = self.get_person_id("qq", user_qq) + person_id = get_person_id("qq", user_qq) print(f"用户person_id: {person_id}") - + # 获取时间范围 - start_timestamp = self.get_time_range(time_period) + start_timestamp = get_time_range(time_period) if start_timestamp: print(f"时间范围: {datetime.fromtimestamp(start_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 至今") else: print("时间范围: 全部时间") - + # 构建查询条件 query = Messages.select() - + # 添加用户条件:包含bot消息或目标用户消息 user_condition = ( - (Messages.user_id == self.bot_qq) | # bot的消息 - (Messages.user_id == user_qq) # 目标用户的消息 + (Messages.user_id == self.bot_qq) # bot的消息 + | (Messages.user_id == user_qq) # 目标用户的消息 ) query = query.where(user_condition) - + # 添加时间条件 if start_timestamp: query = query.where(Messages.time >= start_timestamp) - + # 按时间排序 query = query.order_by(Messages.time.asc()) - + print("正在执行数据库查询...") messages = list(query) print(f"查询到 {len(messages)} 条消息") - + # 按chat_id分组 grouped_messages = defaultdict(list) for msg in messages: msg_dict = { - 'message_id': msg.message_id, - 'time': msg.time, - 'datetime': datetime.fromtimestamp(msg.time).strftime('%Y-%m-%d %H:%M:%S'), - 'chat_id': msg.chat_id, - 'user_id': msg.user_id, - 'user_nickname': msg.user_nickname, - 'user_platform': msg.user_platform, - 'processed_plain_text': msg.processed_plain_text, - 'display_message': msg.display_message, - 'chat_info_group_id': msg.chat_info_group_id, - 'chat_info_group_name': msg.chat_info_group_name, - 'chat_info_platform': msg.chat_info_platform, - 'user_cardname': msg.user_cardname, - 'is_bot_message': msg.user_id == self.bot_qq + "message_id": msg.message_id, + "time": msg.time, + "datetime": datetime.fromtimestamp(msg.time).strftime("%Y-%m-%d %H:%M:%S"), + "chat_id": msg.chat_id, + "user_id": msg.user_id, + "user_nickname": msg.user_nickname, + "user_platform": msg.user_platform, + "processed_plain_text": msg.processed_plain_text, + "display_message": msg.display_message, + "chat_info_group_id": msg.chat_info_group_id, + "chat_info_group_name": msg.chat_info_group_name, + "chat_info_platform": msg.chat_info_platform, + "user_cardname": msg.user_cardname, + "is_bot_message": msg.user_id == self.bot_qq, } grouped_messages[msg.chat_id].append(msg_dict) - + print(f"消息分布在 {len(grouped_messages)} 个聊天中") return dict(grouped_messages) - - def split_messages_by_count(self, messages: List[Dict[str, Any]], count: int = 50) -> List[List[Dict[str, Any]]]: - """将消息按指定数量分段""" - chunks = [] - for i in range(0, len(messages), count): - chunks.append(messages[i:i + count]) - return chunks - - async def build_name_mapping(self, messages: List[Dict[str, Any]], target_person_id: str, target_person_name: str) -> Dict[str, str]: - """构建用户名称映射,和relationship_manager中的逻辑一致""" - name_mapping = {} - current_user = "A" - user_count = 1 - - # 遍历消息,构建映射 - for msg in messages: - await person_info_manager.get_or_create_person( - platform=msg.get("chat_info_platform"), - user_id=msg.get("user_id"), - nickname=msg.get("user_nickname"), - user_cardname=msg.get("user_cardname"), - ) - replace_user_id = msg.get("user_id") - replace_platform = msg.get("chat_info_platform") - replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id) - replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") - - # 跳过机器人自己 - if replace_user_id == global_config.bot.qq_account: - name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}" - continue - - # 跳过目标用户 - if replace_person_name == target_person_name: - name_mapping[replace_person_name] = f"{target_person_name}" - continue - - # 其他用户映射 - if replace_person_name not in name_mapping: - if current_user > 'Z': - current_user = 'A' - user_count += 1 - name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}" - current_user = chr(ord(current_user) + 1) - - return name_mapping - - def build_focus_readable_messages(self, messages: List[Dict[str, Any]], target_person_id: str = None) -> str: - """格式化消息,只保留目标用户和bot消息附近的内容,和relationship_manager中的逻辑一致""" - # 找到目标用户和bot的消息索引 - target_indices = [] - for i, msg in enumerate(messages): - user_id = msg.get("user_id") - platform = msg.get("chat_info_platform") - person_id = person_info_manager.get_person_id(platform, user_id) - if person_id == target_person_id: - target_indices.append(i) - - if not target_indices: - return "" - - # 获取需要保留的消息索引 - keep_indices = set() - for idx in target_indices: - # 获取前后5条消息的索引 - start_idx = max(0, idx - 5) - end_idx = min(len(messages), idx + 6) - keep_indices.update(range(start_idx, end_idx)) - - # 将索引排序 - keep_indices = sorted(list(keep_indices)) - - # 按顺序构建消息组 - message_groups = [] - current_group = [] - - for i in range(len(messages)): - if i in keep_indices: - current_group.append(messages[i]) - elif current_group: - # 如果当前组不为空,且遇到不保留的消息,则结束当前组 - if current_group: - message_groups.append(current_group) - current_group = [] - - # 添加最后一组 - if current_group: - message_groups.append(current_group) - - # 构建最终的消息文本 - result = [] - for i, group in enumerate(message_groups): - if i > 0: - result.append("...") - group_text = build_readable_messages( - messages=group, - replace_bot_name=True, - timestamp_mode="normal_no_YMD", - truncate=False - ) - result.append(group_text) - - return "\n".join(result) - + # 添加相似度检查方法,和relationship_manager一致 - def tfidf_similarity(self, s1, s2): - """使用 TF-IDF 和余弦相似度计算两个句子的相似性""" - # 确保输入是字符串类型 - if isinstance(s1, list): - s1 = " ".join(str(x) for x in s1) - if isinstance(s2, list): - s2 = " ".join(str(x) for x in s2) - - # 转换为字符串类型 - s1 = str(s1) - s2 = str(s2) - - # 1. 使用 jieba 进行分词 - s1_words = " ".join(jieba.cut(s1)) - s2_words = " ".join(jieba.cut(s2)) - - # 2. 将两句话放入一个列表中 - corpus = [s1_words, s2_words] - - # 3. 创建 TF-IDF 向量化器并进行计算 - try: - vectorizer = TfidfVectorizer() - tfidf_matrix = vectorizer.fit_transform(corpus) - except ValueError: - # 如果句子完全由停用词组成,或者为空,可能会报错 - return 0.0 - # 4. 计算余弦相似度 - similarity_matrix = cosine_similarity(tfidf_matrix) - - # 返回 s1 和 s2 的相似度 - return similarity_matrix[0, 1] - - def sequence_similarity(self, s1, s2): - """使用 SequenceMatcher 计算两个句子的相似性""" - return SequenceMatcher(None, s1, s2).ratio() - - def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6): - """使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的""" - # 计算两种相似度 - tfidf_sim = self.tfidf_similarity(text1, text2) - seq_sim = self.sequence_similarity(text1, text2) - - # 只要其中一种方法达到阈值就认为是相似的 - return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold - - def calculate_time_weight(self, point_time: str, current_time: str) -> float: - """计算基于时间的权重系数""" - try: - point_timestamp = datetime.strptime(point_time, "%Y-%m-%d %H:%M:%S") - current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S") - time_diff = current_timestamp - point_timestamp - hours_diff = time_diff.total_seconds() / 3600 - - if hours_diff <= 1: # 1小时内 - return 1.0 - elif hours_diff <= 24: # 1-24小时 - # 从1.0快速递减到0.7 - return 1.0 - (hours_diff - 1) * (0.3 / 23) - elif hours_diff <= 24 * 7: # 24小时-7天 - # 从0.7缓慢回升到0.95 - return 0.7 + (hours_diff - 24) * (0.25 / (24 * 6)) - else: # 7-30天 - # 从0.95缓慢递减到0.1 - days_diff = hours_diff / 24 - 7 - return max(0.1, 0.95 - days_diff * (0.85 / 23)) - except Exception as e: - logger.error(f"计算时间权重失败: {e}") - return 0.5 # 发生错误时返回中等权重 - async def update_person_impression_from_segment(self, person_id: str, readable_messages: str, segment_time: float): """从消息段落更新用户印象,使用和relationship_manager相同的流程""" person_name = await person_info_manager.get_value(person_id, "person_name") nickname = await person_info_manager.get_value(person_id, "nickname") - + if not person_name: logger.warning(f"无法获取用户 {person_id} 的person_name") return - + alias_str = ", ".join(global_config.bot.alias_names) current_time = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S") - + prompt = f""" 你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 @@ -357,17 +454,17 @@ class MessageRetrievalScript: "weight": 0 }} """ - + # 调用LLM生成印象 points, _ = await self.relationship_llm.generate_response_async(prompt=prompt) points = points.strip() - + logger.info(f"LLM分析结果: {points[:200]}...") - + if not points: logger.warning(f"未能从LLM获取 {person_name} 的新印象") return - + # 解析JSON并转换为元组列表 try: points = repair_json(points) @@ -388,11 +485,11 @@ class MessageRetrievalScript: except (KeyError, TypeError) as e: logger.error(f"处理points数据失败: {e}, points: {points}") return - + if not points_list: logger.info(f"用户 {person_name} 的消息段落没有产生新的记忆点") return - + # 获取现有points current_points = await person_info_manager.get_value(person_id, "points") or [] if isinstance(current_points, str): @@ -403,19 +500,19 @@ class MessageRetrievalScript: current_points = [] elif not isinstance(current_points, list): current_points = [] - + # 将新记录添加到现有记录中 for new_point in points_list: similar_points = [] similar_indices = [] - + # 在现有points中查找相似的点 for i, existing_point in enumerate(current_points): # 使用组合的相似度检查方法 - if self.check_similarity(new_point[0], existing_point[0]): + if check_similarity(new_point[0], existing_point[0]): similar_points.append(existing_point) similar_indices.append(i) - + if similar_points: # 合并相似的点 all_points = [new_point] + similar_points @@ -425,14 +522,14 @@ class MessageRetrievalScript: total_weight = sum(p[1] for p in all_points) # 使用最长的描述 longest_desc = max(all_points, key=lambda x: len(x[0]))[0] - + # 创建合并后的点 merged_point = (longest_desc, total_weight, latest_time) - + # 从现有points中移除已合并的点 for idx in sorted(similar_indices, reverse=True): current_points.pop(idx) - + # 添加合并后的点 current_points.append(merged_point) logger.info(f"合并相似记忆点: {longest_desc[:50]}...") @@ -453,29 +550,29 @@ class MessageRetrievalScript: forgotten_points = [] elif not isinstance(forgotten_points, list): forgotten_points = [] - + # 计算当前时间 current_time_str = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S") - + # 计算每个点的最终权重(原始权重 * 时间权重) weighted_points = [] for point in current_points: - time_weight = self.calculate_time_weight(point[2], current_time_str) + time_weight = calculate_time_weight(point[2], current_time_str) final_weight = point[1] * time_weight weighted_points.append((point, final_weight)) - + # 计算总权重 total_weight = sum(w for _, w in weighted_points) - + # 按权重随机选择要保留的点 remaining_points = [] points_to_move = [] - + # 对每个点进行随机选择 for point, weight in weighted_points: # 计算保留概率(权重越高越可能保留) keep_probability = weight / total_weight if total_weight > 0 else 0.5 - + if len(remaining_points) < 10: # 如果还没达到10条,直接保留 remaining_points.append(point) @@ -489,29 +586,28 @@ class MessageRetrievalScript: else: # 不保留这个点 points_to_move.append(point) - + # 更新points和forgotten_points current_points = remaining_points forgotten_points.extend(points_to_move) logger.info(f"将 {len(points_to_move)} 个记忆点移动到forgotten_points") - + # 检查forgotten_points是否达到5条 if len(forgotten_points) >= 10: print(f"forgotten_points: {forgotten_points}") # 构建压缩总结提示词 alias_str = ", ".join(global_config.bot.alias_names) - + # 按时间排序forgotten_points forgotten_points.sort(key=lambda x: x[2]) - + # 构建points文本 - points_text = "\n".join([ - f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" - for point in forgotten_points - ]) - + points_text = "\n".join( + [f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points] + ) + impression = await person_info_manager.get_value(person_id, "impression") or "" - + compress_prompt = f""" 你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。 请不要混淆你自己和{global_config.bot.nickname}和{person_name}。 @@ -532,109 +628,113 @@ class MessageRetrievalScript: """ # 调用LLM生成压缩总结 compressed_summary, _ = await self.relationship_llm.generate_response_async(prompt=compress_prompt) - + current_time_formatted = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S") compressed_summary = f"截至{current_time_formatted},你对{person_name}的了解:{compressed_summary}" - + await person_info_manager.update_one_field(person_id, "impression", compressed_summary) logger.info(f"更新了用户 {person_name} 的总体印象") - + # 清空forgotten_points forgotten_points = [] # 更新数据库 - await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)) - + await person_info_manager.update_one_field( + person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None) + ) + # 更新数据库 - await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)) + await person_info_manager.update_one_field( + person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None) + ) know_times = await person_info_manager.get_value(person_id, "know_times") or 0 await person_info_manager.update_one_field(person_id, "know_times", know_times + 1) await person_info_manager.update_one_field(person_id, "last_know", segment_time) logger.info(f"印象更新完成 for {person_name},新增 {len(points_list)} 个记忆点") - - async def process_segments_and_update_impression(self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]]): + + async def process_segments_and_update_impression( + self, user_qq: str, grouped_messages: Dict[str, List[Dict[str, Any]]] + ): """处理分段消息并更新用户印象到数据库""" # 获取目标用户信息 - target_person_id = self.get_person_id("qq", user_qq) + target_person_id = get_person_id("qq", user_qq) target_person_name = await person_info_manager.get_value(target_person_id, "person_name") - target_nickname = await person_info_manager.get_value(target_person_id, "nickname") - + if not target_person_name: target_person_name = f"用户{user_qq}" - if not target_nickname: - target_nickname = f"用户{user_qq}" - + print(f"\n开始分析用户 {target_person_name} (QQ: {user_qq}) 的消息...") - + total_segments_processed = 0 - + # 收集所有分段并按时间排序 all_segments = [] - + # 为每个chat_id处理消息,收集所有分段 for chat_id, messages in grouped_messages.items(): first_msg = messages[0] - group_name = first_msg.get('chat_info_group_name', '私聊') - + group_name = first_msg.get("chat_info_group_name", "私聊") + print(f"准备聊天: {group_name} (共{len(messages)}条消息)") - + # 将消息按50条分段 - message_chunks = self.split_messages_by_count(messages, 50) - + message_chunks = split_messages_by_count(messages, 50) + for i, chunk in enumerate(message_chunks): # 将分段信息添加到列表中,包含分段时间用于排序 - segment_time = chunk[-1]['time'] - all_segments.append({ - 'chunk': chunk, - 'chat_id': chat_id, - 'group_name': group_name, - 'segment_index': i + 1, - 'total_segments': len(message_chunks), - 'segment_time': segment_time - }) - + segment_time = chunk[-1]["time"] + all_segments.append( + { + "chunk": chunk, + "chat_id": chat_id, + "group_name": group_name, + "segment_index": i + 1, + "total_segments": len(message_chunks), + "segment_time": segment_time, + } + ) + # 按时间排序所有分段 - all_segments.sort(key=lambda x: x['segment_time']) - + all_segments.sort(key=lambda x: x["segment_time"]) + print(f"\n按时间顺序处理 {len(all_segments)} 个分段:") - + # 按时间顺序处理所有分段 for segment_idx, segment_info in enumerate(all_segments, 1): - chunk = segment_info['chunk'] - group_name = segment_info['group_name'] - segment_index = segment_info['segment_index'] - total_segments = segment_info['total_segments'] - segment_time = segment_info['segment_time'] - - segment_time_str = datetime.fromtimestamp(segment_time).strftime('%Y-%m-%d %H:%M:%S') - print(f" [{segment_idx}/{len(all_segments)}] {group_name} 第{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)") - - # 构建名称映射 - name_mapping = await self.build_name_mapping(chunk, target_person_id, target_person_name) - - # 构建可读消息 - readable_messages = self.build_focus_readable_messages( - messages=chunk, - target_person_id=target_person_id + chunk = segment_info["chunk"] + group_name = segment_info["group_name"] + segment_index = segment_info["segment_index"] + total_segments = segment_info["total_segments"] + segment_time = segment_info["segment_time"] + + segment_time_str = datetime.fromtimestamp(segment_time).strftime("%Y-%m-%d %H:%M:%S") + print( + f" [{segment_idx}/{len(all_segments)}] {group_name} 第{segment_index}/{total_segments}段 ({segment_time_str}) (共{len(chunk)}条)" ) - + + # 构建名称映射 + name_mapping = await build_name_mapping(chunk, target_person_name) + + # 构建可读消息 + readable_messages = build_focus_readable_messages(messages=chunk, target_person_id=target_person_id) + if not readable_messages: - print(f" 跳过:该段落没有目标用户的消息") + print(" 跳过:该段落没有目标用户的消息") continue - + # 应用名称映射 for original_name, mapped_name in name_mapping.items(): readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}") - + # 更新用户印象 try: await self.update_person_impression_from_segment(target_person_id, readable_messages, segment_time) total_segments_processed += 1 except Exception as e: logger.error(f"处理段落时出错: {e}") - print(f" 错误:处理该段落时出现异常") - + print(" 错误:处理该段落时出现异常") + # 获取最终统计 final_points = await person_info_manager.get_value(target_person_id, "points") or [] if isinstance(final_points, str): @@ -642,139 +742,45 @@ class MessageRetrievalScript: final_points = json.loads(final_points) except json.JSONDecodeError: final_points = [] - + final_impression = await person_info_manager.get_value(target_person_id, "impression") or "" - - print(f"\n=== 处理完成 ===") + + print("\n=== 处理完成 ===") print(f"目标用户: {target_person_name} (QQ: {user_qq})") print(f"处理段落数: {total_segments_processed}") print(f"当前记忆点数: {len(final_points)}") print(f"是否有总体印象: {'是' if final_impression else '否'}") - + if final_points: print(f"最新记忆点: {final_points[-1][0][:50]}...") - - def display_chat_list(self, grouped_messages: Dict[str, List[Dict[str, Any]]]) -> None: - """显示群聊列表""" - print("\n找到以下群聊:") - print("=" * 60) - - for i, (chat_id, messages) in enumerate(grouped_messages.items(), 1): - first_msg = messages[0] - group_name = first_msg.get('chat_info_group_name', '私聊') - group_id = first_msg.get('chat_info_group_id', chat_id) - - # 计算时间范围 - start_time = datetime.fromtimestamp(messages[0]['time']).strftime('%Y-%m-%d') - end_time = datetime.fromtimestamp(messages[-1]['time']).strftime('%Y-%m-%d') - - print(f"{i:2d}. {group_name}") - print(f" 群ID: {group_id}") - print(f" 消息数: {len(messages)}") - print(f" 时间范围: {start_time} ~ {end_time}") - print("-" * 60) - - def get_user_selection(self, total_count: int) -> List[int]: - """获取用户选择的群聊编号""" - while True: - print(f"\n请选择要分析的群聊 (1-{total_count}):") - print("输入格式:") - print(" 单个: 1") - print(" 多个: 1,3,5") - print(" 范围: 1-3") - print(" 全部: all 或 a") - print(" 退出: quit 或 q") - - user_input = input("请输入选择: ").strip().lower() - - if user_input in ['quit', 'q']: - return [] - - if user_input in ['all', 'a']: - return list(range(1, total_count + 1)) - - try: - selected = [] - - # 处理逗号分隔的输入 - parts = user_input.split(',') - - for part in parts: - part = part.strip() - - if '-' in part: - # 处理范围输入 (如: 1-3) - start, end = part.split('-') - start_num = int(start.strip()) - end_num = int(end.strip()) - - if 1 <= start_num <= total_count and 1 <= end_num <= total_count and start_num <= end_num: - selected.extend(range(start_num, end_num + 1)) - else: - raise ValueError("范围超出有效范围") - else: - # 处理单个数字 - num = int(part) - if 1 <= num <= total_count: - selected.append(num) - else: - raise ValueError("数字超出有效范围") - - # 去重并排序 - selected = sorted(list(set(selected))) - - if selected: - return selected - else: - print("错误: 请输入有效的选择") - - except ValueError as e: - print(f"错误: 输入格式无效 - {e}") - print("请重新输入") - - def filter_selected_chats(self, grouped_messages: Dict[str, List[Dict[str, Any]]], selected_indices: List[int]) -> Dict[str, List[Dict[str, Any]]]: - """根据用户选择过滤群聊""" - chat_items = list(grouped_messages.items()) - selected_chats = {} - - for idx in selected_indices: - chat_id, messages = chat_items[idx - 1] # 转换为0基索引 - selected_chats[chat_id] = messages - - return selected_chats async def run(self): """运行脚本""" print("=== 消息检索分析脚本 ===") - + # 获取用户输入 user_qq = input("请输入用户QQ号: ").strip() if not user_qq: print("QQ号不能为空") return - + print("\n时间段选择:") print("1. 全部时间 (all)") print("2. 最近3个月 (3months)") print("3. 最近1个月 (1month)") print("4. 最近1周 (1week)") - + choice = input("请选择时间段 (1-4): ").strip() - time_periods = { - "1": "all", - "2": "3months", - "3": "1month", - "4": "1week" - } - + time_periods = {"1": "all", "2": "3months", "3": "1month", "4": "1week"} + if choice not in time_periods: print("选择无效") return - + time_period = time_periods[choice] - + print(f"\n开始处理用户 {user_qq} 在时间段 {time_period} 的消息...") - + # 连接数据库 try: db.connect(reuse_if_open=True) @@ -782,57 +788,59 @@ class MessageRetrievalScript: except Exception as e: print(f"数据库连接失败: {e}") return - + try: # 检索消息 grouped_messages = self.retrieve_messages(user_qq, time_period) - + if not grouped_messages: print("未找到任何消息") return - + # 显示群聊列表 - self.display_chat_list(grouped_messages) - + display_chat_list(grouped_messages) + # 获取用户选择 - selected_indices = self.get_user_selection(len(grouped_messages)) - + selected_indices = get_user_selection(len(grouped_messages)) + if not selected_indices: print("已取消操作") return - + # 过滤选中的群聊 - selected_chats = self.filter_selected_chats(grouped_messages, selected_indices) - + selected_chats = filter_selected_chats(grouped_messages, selected_indices) + # 显示选中的群聊 print(f"\n已选择 {len(selected_chats)} 个群聊进行分析:") - for i, (chat_id, messages) in enumerate(selected_chats.items(), 1): + for i, (_, messages) in enumerate(selected_chats.items(), 1): first_msg = messages[0] - group_name = first_msg.get('chat_info_group_name', '私聊') + group_name = first_msg.get("chat_info_group_name", "私聊") print(f" {i}. {group_name} ({len(messages)}条消息)") - + # 确认处理 - confirm = input(f"\n确认分析这些群聊吗? (y/n): ").strip().lower() - if confirm != 'y': + confirm = input("\n确认分析这些群聊吗? (y/n): ").strip().lower() + if confirm != "y": print("已取消操作") return - + # 处理分段消息并更新数据库 await self.process_segments_and_update_impression(user_qq, selected_chats) - + except Exception as e: print(f"处理过程中出现错误: {e}") import traceback + traceback.print_exc() finally: db.close() print("数据库连接已关闭") + def main(): """主函数""" script = MessageRetrievalScript() asyncio.run(script.run()) + if __name__ == "__main__": main() - diff --git a/src/chat/heart_flow/utils_chat.py b/src/chat/heart_flow/utils_chat.py index 7289db1a8..22581e482 100644 --- a/src/chat/heart_flow/utils_chat.py +++ b/src/chat/heart_flow/utils_chat.py @@ -1,7 +1,7 @@ from typing import Optional, Tuple, Dict from src.common.logger_manager import get_logger from src.chat.message_receive.chat_stream import chat_manager -from src.person_info.person_info import person_info_manager +from src.person_info.person_info import person_info_manager, PersonInfoManager logger = get_logger("heartflow_utils") @@ -47,7 +47,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]: # Try to fetch person info try: # Assume get_person_id is sync (as per original code), keep using to_thread - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) person_name = None if person_id: # get_value is async, so await it directly diff --git a/src/chat/normal_chat/normal_chat_generator.py b/src/chat/normal_chat/normal_chat_generator.py index 41ac71492..65e60e963 100644 --- a/src/chat/normal_chat/normal_chat_generator.py +++ b/src/chat/normal_chat/normal_chat_generator.py @@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageThinking from src.chat.normal_chat.normal_prompt import prompt_builder from src.chat.utils.timer_calculator import Timer from src.common.logger_manager import get_logger -from src.person_info.person_info import person_info_manager +from src.person_info.person_info import person_info_manager, PersonInfoManager from src.chat.utils.utils import process_llm_response @@ -66,7 +66,7 @@ class NormalChatGenerator: enable_planner: bool = False, available_actions=None, ): - person_id = person_info_manager.get_person_id( + person_id = PersonInfoManager.get_person_id( message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id ) diff --git a/src/chat/normal_chat/willing/willing_manager.py b/src/chat/normal_chat/willing/willing_manager.py index 4080ae8e8..09f303a6e 100644 --- a/src/chat/normal_chat/willing/willing_manager.py +++ b/src/chat/normal_chat/willing/willing_manager.py @@ -96,7 +96,7 @@ class BaseWillingManager(ABC): self.logger: LoguruLogger = logger def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float): - person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id) + person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id) self.ongoing_messages[message.message_info.message_id] = WillingInfo( message=message, chat=chat, diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index da6ff5e58..d4b7d4646 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -4,7 +4,7 @@ import time # 导入 time 模块以获取当前时间 import random import re from src.common.message_repository import find_messages, count_messages -from src.person_info.person_info import person_info_manager +from src.person_info.person_info import person_info_manager, PersonInfoManager from src.chat.utils.utils import translate_timestamp_to_human_readable from rich.traceback import install from src.common.database.database_model import ActionRecords @@ -219,7 +219,7 @@ def _build_readable_messages_internal( if not all([platform, user_id, timestamp is not None]): continue - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) # 根据 replace_bot_name 参数决定是否替换机器人名称 if replace_bot_name and user_id == global_config.bot.qq_account: person_name = f"{global_config.bot.nickname}(你)" @@ -241,7 +241,7 @@ def _build_readable_messages_internal( if match: aaa = match.group(1) bbb = match.group(2) - reply_person_id = person_info_manager.get_person_id(platform, bbb) + reply_person_id = PersonInfoManager.get_person_id(platform, bbb) reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name") if not reply_person_name: reply_person_name = aaa @@ -258,7 +258,7 @@ def _build_readable_messages_internal( new_content += content[last_end : m.start()] aaa = m.group(1) bbb = m.group(2) - at_person_id = person_info_manager.get_person_id(platform, bbb) + at_person_id = PersonInfoManager.get_person_id(platform, bbb) at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name") if not at_person_name: at_person_name = aaa @@ -572,7 +572,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # print("SELF11111111111111") return "SELF" try: - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) except Exception as _e: person_id = None if not person_id: @@ -673,7 +673,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) # 只有当获取到有效 person_id 时才添加 if person_id: diff --git a/src/person_info/impression_update_task.py b/src/person_info/impression_update_task.py index 98b6ede36..480090163 100644 --- a/src/person_info/impression_update_task.py +++ b/src/person_info/impression_update_task.py @@ -1,9 +1,9 @@ from src.manager.async_task_manager import AsyncTask from src.common.logger_manager import get_logger +from src.person_info.person_info import PersonInfoManager from src.person_info.relationship_manager import relationship_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp from src.config.config import global_config -from src.person_info.person_info import person_info_manager from src.chat.message_receive.chat_stream import chat_manager import time import random @@ -95,7 +95,7 @@ class ImpressionUpdateTask(AsyncTask): if msg["user_nickname"] == global_config.bot.nickname: continue - person_id = person_info_manager.get_person_id(msg["chat_info_platform"], msg["user_id"]) + person_id = PersonInfoManager.get_person_id(msg["chat_info_platform"], msg["user_id"]) if not person_id: logger.warning(f"未找到用户 {msg['user_nickname']} 的person_id") continue diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 19b53be1c..0029e6492 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -1,6 +1,6 @@ from src.common.logger_manager import get_logger import math -from src.person_info.person_info import person_info_manager +from src.person_info.person_info import person_info_manager, PersonInfoManager import time import random from src.llm_models.utils_model import LLMRequest @@ -91,7 +91,7 @@ class RelationshipManager: @staticmethod async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str): """判断是否认识某人""" - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) # 生成唯一的 person_name unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname) data = { @@ -116,7 +116,7 @@ class RelationshipManager: if is_id: person_id = person else: - person_id = person_info_manager.get_person_id(person[0], person[1]) + person_id = PersonInfoManager.get_person_id(person[0], person[1]) person_name = await person_info_manager.get_value(person_id, "person_name") if not person_name or person_name == "none": @@ -198,7 +198,7 @@ class RelationshipManager: ) replace_user_id = msg.get("user_id") replace_platform = msg.get("chat_info_platform") - replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id) + replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id) replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name") # 跳过机器人自己 @@ -467,7 +467,7 @@ class RelationshipManager: for i, msg in enumerate(messages): user_id = msg.get("user_id") platform = msg.get("chat_info_platform") - person_id = person_info_manager.get_person_id(platform, user_id) + person_id = PersonInfoManager.get_person_id(platform, user_id) if person_id == target_person_id: target_indices.append(i)