diff --git a/scripts/analyze_group_similarity.py b/scripts/analyze_group_similarity.py
index 86863a52c..7831b62bd 100644
--- a/scripts/analyze_group_similarity.py
+++ b/scripts/analyze_group_similarity.py
@@ -1,37 +1,38 @@
 import json
-import os
 from pathlib import Path
 import numpy as np
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import matplotlib.pyplot as plt
 import seaborn as sns
-import networkx as nx
-import matplotlib as mpl
 import sqlite3

 # 设置中文字体
-plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # 使用微软雅黑
-plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
-plt.rcParams['font.family'] = 'sans-serif'
+plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 使用微软雅黑
+plt.rcParams["axes.unicode_minus"] = False  # 用来正常显示负号
+plt.rcParams["font.family"] = "sans-serif"

 # 获取脚本所在目录
 SCRIPT_DIR = Path(__file__).parent

+
 def get_group_name(stream_id):
     """从数据库中获取群组名称"""
-    conn = sqlite3.connect('data/maibot.db')
+    conn = sqlite3.connect("data/maibot.db")
     cursor = conn.cursor()
-
-    cursor.execute('''
+
+    cursor.execute(
+        """
         SELECT group_name, user_nickname, platform
         FROM chat_streams
         WHERE stream_id = ?
-    ''', (stream_id,))
-
+    """,
+        (stream_id,),
+    )
+
     result = cursor.fetchone()
     conn.close()
-
+
     if result:
         group_name, user_nickname, platform = result
         if group_name:
@@ -42,139 +43,148 @@ def get_group_name(stream_id):
         return f"{platform}-{stream_id[:8]}"
     return stream_id

+
 def load_group_data(group_dir):
     """加载单个群组的数据"""
     json_path = Path(group_dir) / "expressions.json"
     if not json_path.exists():
         return [], [], []
-
-    with open(json_path, 'r', encoding='utf-8') as f:
+
+    with open(json_path, "r", encoding="utf-8") as f:
         data = json.load(f)
-
+
     situations = []
     styles = []
     combined = []
-
+
     for item in data:
-        count = item['count']
-        situations.extend([item['situation']] * count)
-        styles.extend([item['style']] * count)
+        count = item["count"]
+        situations.extend([item["situation"]] * count)
+        styles.extend([item["style"]] * count)
         combined.extend([f"{item['situation']} {item['style']}"] * count)
-
+
     return situations, styles, combined

+
 def analyze_group_similarity():
     # 获取所有群组目录
     base_dir = Path("data/expression/learnt_style")
     group_dirs = [d for d in base_dir.iterdir() if d.is_dir()]
     group_ids = [d.name for d in group_dirs]
-
+
     # 获取群组名称
     group_names = [get_group_name(group_id) for group_id in group_ids]
-
+
     # 加载所有群组的数据
     group_situations = []
     group_styles = []
     group_combined = []
-
+
     for d in group_dirs:
         situations, styles, combined = load_group_data(d)
-        group_situations.append(' '.join(situations))
-        group_styles.append(' '.join(styles))
-        group_combined.append(' '.join(combined))
-
+        group_situations.append(" ".join(situations))
+        group_styles.append(" ".join(styles))
+        group_combined.append(" ".join(combined))
+
     # 创建TF-IDF向量化器
     vectorizer = TfidfVectorizer()
-
+
     # 计算三种相似度矩阵
     situation_matrix = cosine_similarity(vectorizer.fit_transform(group_situations))
     style_matrix = cosine_similarity(vectorizer.fit_transform(group_styles))
     combined_matrix = cosine_similarity(vectorizer.fit_transform(group_combined))
-
+
     # 对相似度矩阵进行对数变换
     log_situation_matrix = np.log1p(situation_matrix)
     log_style_matrix = np.log1p(style_matrix)
    log_combined_matrix = np.log1p(combined_matrix)
-
+
     # 创建一个大图,包含三个子图
     plt.figure(figsize=(45, 12))
-
+
     # 场景相似度热力图
     plt.subplot(1, 3, 1)
-    sns.heatmap(log_situation_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组场景相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
-
+    sns.heatmap(
+        log_situation_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组场景相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")
+
     # 表达方式相似度热力图
     plt.subplot(1, 3, 2)
-    sns.heatmap(log_style_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组表达方式相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
-
+    sns.heatmap(
+        log_style_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组表达方式相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")
+
     # 组合相似度热力图
     plt.subplot(1, 3, 3)
-    sns.heatmap(log_combined_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组场景+表达方式相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
-
+    sns.heatmap(
+        log_combined_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组场景+表达方式相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")
+
     plt.tight_layout()
-    plt.savefig(SCRIPT_DIR / 'group_similarity_heatmaps.png', dpi=300, bbox_inches='tight')
+    plt.savefig(SCRIPT_DIR / "group_similarity_heatmaps.png", dpi=300, bbox_inches="tight")
    plt.close()
-
+
     # 保存匹配详情到文本文件
-    with open(SCRIPT_DIR / 'group_similarity_details.txt', 'w', encoding='utf-8') as f:
-        f.write('群组相似度详情\n')
-        f.write('=' * 50 + '\n\n')
-
+    with open(SCRIPT_DIR / "group_similarity_details.txt", "w", encoding="utf-8") as f:
+        f.write("群组相似度详情\n")
+        f.write("=" * 50 + "\n\n")
+
         for i in range(len(group_ids)):
-            for j in range(i+1, len(group_ids)):
+            for j in range(i + 1, len(group_ids)):
                 if log_combined_matrix[i][j] > np.log1p(0.05):
-                    f.write(f'群组1: {group_names[i]}\n')
-                    f.write(f'群组2: {group_names[j]}\n')
-                    f.write(f'场景相似度: {situation_matrix[i][j]:.4f}\n')
-                    f.write(f'表达方式相似度: {style_matrix[i][j]:.4f}\n')
-                    f.write(f'组合相似度: {combined_matrix[i][j]:.4f}\n')
-
+                    f.write(f"群组1: {group_names[i]}\n")
+                    f.write(f"群组2: {group_names[j]}\n")
+                    f.write(f"场景相似度: {situation_matrix[i][j]:.4f}\n")
+                    f.write(f"表达方式相似度: {style_matrix[i][j]:.4f}\n")
+                    f.write(f"组合相似度: {combined_matrix[i][j]:.4f}\n")
+
                     # 获取两个群组的数据
                     situations1, styles1, _ = load_group_data(group_dirs[i])
                     situations2, styles2, _ = load_group_data(group_dirs[j])
-
+
                     # 找出共同的场景
                     common_situations = set(situations1) & set(situations2)
                     if common_situations:
-                        f.write('\n共同场景:\n')
+                        f.write("\n共同场景:\n")
                         for situation in common_situations:
-                            f.write(f'- {situation}\n')
-
+                            f.write(f"- {situation}\n")
+
                     # 找出共同的表达方式
                     common_styles = set(styles1) & set(styles2)
                     if common_styles:
-                        f.write('\n共同表达方式:\n')
+                        f.write("\n共同表达方式:\n")
                         for style in common_styles:
-                            f.write(f'- {style}\n')
-
-                    f.write('\n' + '-' * 50 + '\n\n')
+                            f.write(f"- {style}\n")
+
+                    f.write("\n" + "-" * 50 + "\n\n")
+

 if __name__ == "__main__":
     analyze_group_similarity()
diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py
index da34c7d44..1a1793f40 100644
--- a/scripts/mongodb_to_sqlite.py
+++ b/scripts/mongodb_to_sqlite.py
@@ -32,7 +32,6 @@ from rich.panel import Panel
 from src.common.database.database import db
 from src.common.database.database_model import (
     ChatStreams,
-    LLMUsage,
     Emoji,
     Messages,
     Images,
diff --git a/scripts/preview_expressions.py b/scripts/preview_expressions.py
index 0eebfb442..1e71120d8 100644
--- a/scripts/preview_expressions.py
+++ b/scripts/preview_expressions.py
@@ -8,162 +8,174 @@ import matplotlib.pyplot as plt
 from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
 from collections import defaultdict

+
 class ExpressionViewer:
     def __init__(self, root):
         self.root = root
         self.root.title("表达方式预览器")
         self.root.geometry("1200x800")
-
+
         # 创建主框架
         self.main_frame = ttk.Frame(root)
         self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
-
+
         # 创建左侧控制面板
         self.control_frame = ttk.Frame(self.main_frame)
         self.control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))
-
+
         # 创建搜索框
         self.search_frame = ttk.Frame(self.control_frame)
         self.search_frame.pack(fill=tk.X, pady=(0, 10))
-
+
         self.search_var = tk.StringVar()
-        self.search_var.trace('w', self.filter_expressions)
+        self.search_var.trace("w", self.filter_expressions)
         self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var)
         self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
         ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5))
-
+
         # 创建文件选择下拉框
         self.file_var = tk.StringVar()
         self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var)
         self.file_combo.pack(side=tk.LEFT, padx=5)
-        self.file_combo.bind('<<ComboboxSelected>>', self.load_file)
-
+        self.file_combo.bind("<<ComboboxSelected>>", self.load_file)
+
         # 创建排序选项
         self.sort_frame = ttk.LabelFrame(self.control_frame, text="排序选项")
         self.sort_frame.pack(fill=tk.X, pady=5)
-
+
         self.sort_var = tk.StringVar(value="count")
-        ttk.Radiobutton(self.sort_frame, text="按计数排序", variable=self.sort_var,
-                        value="count", command=self.apply_sort).pack(anchor=tk.W)
-        ttk.Radiobutton(self.sort_frame, text="按情境排序", variable=self.sort_var,
-                        value="situation", command=self.apply_sort).pack(anchor=tk.W)
-        ttk.Radiobutton(self.sort_frame, text="按风格排序", variable=self.sort_var,
-                        value="style", command=self.apply_sort).pack(anchor=tk.W)
-
+        ttk.Radiobutton(
+            self.sort_frame, text="按计数排序", variable=self.sort_var, value="count", command=self.apply_sort
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.sort_frame, text="按情境排序", variable=self.sort_var, value="situation", command=self.apply_sort
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.sort_frame, text="按风格排序", variable=self.sort_var, value="style", command=self.apply_sort
+        ).pack(anchor=tk.W)
+
         # 创建分群选项
         self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项")
         self.group_frame.pack(fill=tk.X, pady=5)
-
+
         self.group_var = tk.StringVar(value="none")
-        ttk.Radiobutton(self.group_frame, text="不分群", variable=self.group_var,
-                        value="none", command=self.apply_grouping).pack(anchor=tk.W)
-        ttk.Radiobutton(self.group_frame, text="按情境分群", variable=self.group_var,
-                        value="situation", command=self.apply_grouping).pack(anchor=tk.W)
-        ttk.Radiobutton(self.group_frame, text="按风格分群", variable=self.group_var,
-                        value="style", command=self.apply_grouping).pack(anchor=tk.W)
-
+        ttk.Radiobutton(
+            self.group_frame, text="不分群", variable=self.group_var, value="none", command=self.apply_grouping
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.group_frame, text="按情境分群", variable=self.group_var, value="situation", command=self.apply_grouping
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.group_frame, text="按风格分群", variable=self.group_var, value="style", command=self.apply_grouping
+        ).pack(anchor=tk.W)
+
         # 创建相似度阈值滑块
         self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置")
         self.similarity_frame.pack(fill=tk.X, pady=5)
-
+
         self.similarity_var = tk.DoubleVar(value=0.5)
-        self.similarity_scale = ttk.Scale(self.similarity_frame, from_=0.0, to=1.0,
-                                          variable=self.similarity_var, orient=tk.HORIZONTAL,
-                                          command=self.update_similarity)
+        self.similarity_scale = ttk.Scale(
+            self.similarity_frame,
+            from_=0.0,
+            to=1.0,
+            variable=self.similarity_var,
+            orient=tk.HORIZONTAL,
+            command=self.update_similarity,
+        )
         self.similarity_scale.pack(fill=tk.X, padx=5, pady=5)
         ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack()
-
+
         # 创建显示选项
         self.view_frame = ttk.LabelFrame(self.control_frame, text="显示选项")
         self.view_frame.pack(fill=tk.X, pady=5)
-
+
         self.show_graph_var = tk.BooleanVar(value=True)
-        ttk.Checkbutton(self.view_frame, text="显示关系图", variable=self.show_graph_var,
-                        command=self.toggle_graph).pack(anchor=tk.W)
-
+        ttk.Checkbutton(
+            self.view_frame, text="显示关系图", variable=self.show_graph_var, command=self.toggle_graph
+        ).pack(anchor=tk.W)
+
         # 创建右侧内容区域
         self.content_frame = ttk.Frame(self.main_frame)
         self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
-
+
         # 创建文本显示区域
         self.text_area = tk.Text(self.content_frame, wrap=tk.WORD)
         self.text_area.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
-
+
         # 添加滚动条
         scrollbar = ttk.Scrollbar(self.text_area, command=self.text_area.yview)
         scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
         self.text_area.config(yscrollcommand=scrollbar.set)
-
+
         # 创建图形显示区域
         self.graph_frame = ttk.Frame(self.content_frame)
         self.graph_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
-
+
         # 初始化数据
         self.current_data = []
         self.graph = nx.Graph()
         self.canvas = None
-
+
         # 加载文件列表
         self.load_file_list()
-
+
     def load_file_list(self):
         expression_dir = Path("data/expression")
         files = []
         for root, _, filenames in os.walk(expression_dir):
             for filename in filenames:
-                if filename.endswith('.json'):
+                if filename.endswith(".json"):
                     rel_path = os.path.relpath(os.path.join(root, filename), expression_dir)
                     files.append(rel_path)
-
-        self.file_combo['values'] = files
+
+        self.file_combo["values"] = files
         if files:
             self.file_combo.set(files[0])
             self.load_file(None)
-
+
     def load_file(self, event):
         selected_file = self.file_var.get()
         if not selected_file:
             return
-
+
         file_path = os.path.join("data/expression", selected_file)
         try:
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, "r", encoding="utf-8") as f:
                 self.current_data = json.load(f)
-
+
             self.apply_sort()
             self.update_similarity()
-
+
         except Exception as e:
             self.text_area.delete(1.0, tk.END)
             self.text_area.insert(tk.END, f"加载文件时出错: {str(e)}")
-
+
     def apply_sort(self):
         if not self.current_data:
             return
-
+
         sort_key = self.sort_var.get()
         reverse = sort_key == "count"
-
+
         self.current_data.sort(key=lambda x: x.get(sort_key, ""), reverse=reverse)
         self.apply_grouping()
-
+
     def apply_grouping(self):
         if not self.current_data:
             return
-
+
         group_key = self.group_var.get()
         if group_key == "none":
             self.display_data(self.current_data)
             return
-
+
         grouped_data = defaultdict(list)
         for item in self.current_data:
             key = item.get(group_key, "未分类")
             grouped_data[key].append(item)
-
+
         self.text_area.delete(1.0, tk.END)
         for group, items in grouped_data.items():
             self.text_area.insert(tk.END, f"\n=== {group} ===\n\n")
@@ -172,7 +184,7 @@ class ExpressionViewer:
                 self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
                 self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
                 self.text_area.insert(tk.END, "-" * 50 + "\n")
-
+
     def display_data(self, data):
         self.text_area.delete(1.0, tk.END)
         for item in data:
@@ -180,59 +192,58 @@ class ExpressionViewer:
             self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
             self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
             self.text_area.insert(tk.END, "-" * 50 + "\n")
-
+
     def update_similarity(self, *args):
         if not self.current_data:
             return
-
+
         threshold = self.similarity_var.get()
         self.similarity_frame.winfo_children()[-1].config(text=f"相似度阈值: {threshold:.2f}")
-
+
         # 计算相似度
         texts = [f"{item['situation']} {item['style']}" for item in self.current_data]
         vectorizer = TfidfVectorizer()
         tfidf_matrix = vectorizer.fit_transform(texts)
         similarity_matrix = cosine_similarity(tfidf_matrix)
-
+
         # 创建图
         self.graph.clear()
         for i, item in enumerate(self.current_data):
             self.graph.add_node(i, label=f"{item['situation']}\n{item['style']}")
-
+
         # 添加边
         for i in range(len(self.current_data)):
             for j in range(i + 1, len(self.current_data)):
                 if similarity_matrix[i, j] > threshold:
                     self.graph.add_edge(i, j, weight=similarity_matrix[i, j])
-
+
         if self.show_graph_var.get():
             self.draw_graph()
-
+
     def draw_graph(self):
         if self.canvas:
             self.canvas.get_tk_widget().destroy()
-
+
         fig = plt.figure(figsize=(8, 6))
         pos = nx.spring_layout(self.graph)
-
+
         # 绘制节点
-        nx.draw_networkx_nodes(self.graph, pos, node_color='lightblue',
-                               node_size=1000, alpha=0.6)
-
+        nx.draw_networkx_nodes(self.graph, pos, node_color="lightblue", node_size=1000, alpha=0.6)
+
         # 绘制边
         nx.draw_networkx_edges(self.graph, pos, alpha=0.4)
-
+
         # 添加标签
-        labels = nx.get_node_attributes(self.graph, 'label')
+        labels = nx.get_node_attributes(self.graph, "label")
         nx.draw_networkx_labels(self.graph, pos, labels, font_size=8)
-
+
         plt.title("表达方式关系图")
-        plt.axis('off')
-
+        plt.axis("off")
+
         self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame)
         self.canvas.draw()
         self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
-
+
     def toggle_graph(self):
         if self.show_graph_var.get():
             self.draw_graph()
@@ -240,26 +251,28 @@ class ExpressionViewer:
             if self.canvas:
                 self.canvas.get_tk_widget().destroy()
                 self.canvas = None
-
+
     def filter_expressions(self, *args):
         search_text = self.search_var.get().lower()
         if not search_text:
             self.apply_sort()
             return
-
+
         filtered_data = []
         for item in self.current_data:
-            situation = item.get('situation', '').lower()
-            style = item.get('style', '').lower()
+            situation = item.get("situation", "").lower()
+            style = item.get("style", "").lower()
             if search_text in situation or search_text in style:
                 filtered_data.append(item)
-
+
         self.display_data(filtered_data)

+
 def main():
     root = tk.Tk()
-    app = ExpressionViewer(root)
+    ExpressionViewer(root)
     root.mainloop()

+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 2c2081fd5..68495df61 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -602,9 +602,10 @@ class EmojiManager:
                     continue

                 # 检查是否需要处理表情包(数量超过最大值或不足)
-                if global_config.emoji.steal_emoji and ((self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
-                    self.emoji_num < self.emoji_num_max
-                )):
+                if global_config.emoji.steal_emoji and (
+                    (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace)
+                    or (self.emoji_num < self.emoji_num_max)
+                ):
                     try:
                         # 获取目录下所有图片文件
                         files_to_process = [
diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py
index 616590fa5..c05826c82 100644
--- a/src/chat/focus_chat/info_processors/working_memory_processor.py
+++ b/src/chat/focus_chat/info_processors/working_memory_processor.py
@@ -120,8 +120,10 @@ class WorkingMemoryProcessor(BaseProcessor):
             memory_str=memory_choose_str,
         )

+        # print(f"prompt: {prompt}")
+
         # 调用LLM处理记忆
         content = ""
         try:
diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py
index f68aa5834..867ea948b 100644
--- a/src/chat/focus_chat/planners/actions/plugin_action.py
+++ b/src/chat/focus_chat/planners/actions/plugin_action.py
@@ -194,7 +194,7 @@ class PluginAction(BaseAction):

         # 获取锚定消息(如果有)
         observations = self._services.get("observations", [])
-
+
         # 查找 ChattingObservation 实例
         chatting_observation = None
         for obs in observations:
diff --git a/src/chat/focus_chat/planners/planner_simple.py b/src/chat/focus_chat/planners/planner_simple.py
index 2b943be6c..8250e8a97 100644
--- a/src/chat/focus_chat/planners/planner_simple.py
+++ b/src/chat/focus_chat/planners/planner_simple.py
@@ -226,14 +226,14 @@ class ActionPlanner(BasePlanner):
                             action_data[key] = value

                     action_data["identity"] = self_info
-
+
                     extra_info_block = "\n".join(extra_info)
                     extra_info_block += f"\n{structured_info}"
                     if extra_info or structured_info:
                         extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
                     else:
                         extra_info_block = ""
-
+
                     action_data["extra_info_block"] = extra_info_block

                     # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
@@ -272,9 +272,6 @@ class ActionPlanner(BasePlanner):
             )

             action_result = {"action_type": action, "action_data": action_data, "reasoning": reasoning}
-
-
-

             plan_result = {
                 "action_result": action_result,
diff --git a/src/chat/focus_chat/working_memory/memory_item.py b/src/chat/focus_chat/working_memory/memory_item.py
index 14161f92d..dc8355252 100644
--- a/src/chat/focus_chat/working_memory/memory_item.py
+++ b/src/chat/focus_chat/working_memory/memory_item.py
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional, Set, Tuple
+from typing import Dict, Any, Tuple
 import time
 import random
 import string
diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py
index 21aa94707..1e8ae4912 100644
--- a/src/chat/focus_chat/working_memory/memory_manager.py
+++ b/src/chat/focus_chat/working_memory/memory_manager.py
@@ -286,110 +286,110 @@ class MemoryManager:
                 logger.error(f"生成总结时出错: {str(e)}")
                 return default_summary

-# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
-#     """
-#     对记忆进行精简操作,根据要求修改要点、总结和概括
+    # async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
+    #     """
+    #     对记忆进行精简操作,根据要求修改要点、总结和概括

-#     Args:
-#         memory_id: 记忆ID
-#         requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
+    #     Args:
+    #         memory_id: 记忆ID
+    #         requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点

-#     Returns:
-#         修改后的记忆总结字典
-#     """
-#     # 获取指定ID的记忆项
-#     logger.info(f"精简记忆: {memory_id}")
-#     memory_item = self.get_by_id(memory_id)
-#     if not memory_item:
-#         raise ValueError(f"未找到ID为{memory_id}的记忆项")
+    #     Returns:
+    #         修改后的记忆总结字典
+    #     """
+    #     # 获取指定ID的记忆项
+    #     logger.info(f"精简记忆: {memory_id}")
+    #     memory_item = self.get_by_id(memory_id)
+    #     if not memory_item:
+    #         raise ValueError(f"未找到ID为{memory_id}的记忆项")

-#     # 增加精简次数
-#     memory_item.increase_compress_count()
+    #     # 增加精简次数
+    #     memory_item.increase_compress_count()

-#     summary = memory_item.summary
+    #     summary = memory_item.summary

-#     # 使用LLM根据要求对总结、概括和要点进行精简修改
-#     prompt = f"""
-#     请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
-#     要求:{requirements}
-#     你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题
+    #     # 使用LLM根据要求对总结、概括和要点进行精简修改
+    #     prompt = f"""
+    #     请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
+    #     要求:{requirements}
+    #     你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题

-#     目前主题:{summary["brief"]}
+    #     目前主题:{summary["brief"]}

-#     目前关键要点:
-#     {chr(10).join([f"- {point}" for point in summary.get("points", [])])}
+    #     目前关键要点:
+    #     {chr(10).join([f"- {point}" for point in summary.get("points", [])])}

-#     请生成修改后的主题和关键要点,遵循以下格式:
-#     ```json
-#     {{
-#         "brief": "修改后的主题(20字以内)",
-#         "points": [
-#             "修改后的要点",
-#             "修改后的要点"
-#         ]
-#     }}
-#     ```
-#     请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
-#     """
-#     # 定义默认的精简结果
-#     default_refined = {
-#         "brief": summary["brief"],
-#         "points": summary.get("points", ["未知的要点"])[:1],  # 默认只保留第一个要点
-#     }
+    #     请生成修改后的主题和关键要点,遵循以下格式:
+    #     ```json
+    #     {{
+    #         "brief": "修改后的主题(20字以内)",
+    #         "points": [
+    #             "修改后的要点",
+    #             "修改后的要点"
+    #         ]
+    #     }}
+    #     ```
+    #     请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
+    #     """
+    #     # 定义默认的精简结果
+    #     default_refined = {
+    #         "brief": summary["brief"],
+    #         "points": summary.get("points", ["未知的要点"])[:1],  # 默认只保留第一个要点
+    #     }

-#     try:
-#         # 调用LLM修改总结、概括和要点
-#         response, _ = await self.llm_summarizer.generate_response_async(prompt)
-#         logger.debug(f"精简记忆响应: {response}")
-#         # 使用repair_json处理响应
-#         try:
-#             # 修复JSON格式
-#             fixed_json_string = repair_json(response)
+    #     try:
+    #         # 调用LLM修改总结、概括和要点
+    #         response, _ = await self.llm_summarizer.generate_response_async(prompt)
+    #         logger.debug(f"精简记忆响应: {response}")
+    #         # 使用repair_json处理响应
+    #         try:
+    #             # 修复JSON格式
+    #             fixed_json_string = repair_json(response)

-#             # 将修复后的字符串解析为Python对象
-#             if isinstance(fixed_json_string, str):
-#                 try:
-#                     refined_data = json.loads(fixed_json_string)
-#                 except json.JSONDecodeError as decode_error:
-#                     logger.error(f"JSON解析错误: {str(decode_error)}")
-#                     refined_data = default_refined
-#             else:
-#                 # 如果repair_json直接返回了字典对象,直接使用
-#                 refined_data = fixed_json_string
+    #             # 将修复后的字符串解析为Python对象
+    #             if isinstance(fixed_json_string, str):
+    #                 try:
+    #                     refined_data = json.loads(fixed_json_string)
+    #                 except json.JSONDecodeError as decode_error:
+    #                     logger.error(f"JSON解析错误: {str(decode_error)}")
+    #                     refined_data = default_refined
+    #             else:
+    #                 # 如果repair_json直接返回了字典对象,直接使用
+    #                 refined_data = fixed_json_string

-#             # 确保是字典类型
-#             if not isinstance(refined_data, dict):
-#                 logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
-#                 refined_data = default_refined
+    #             # 确保是字典类型
+    #             if not isinstance(refined_data, dict):
+    #                 logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
+    #                 refined_data = default_refined

-#             # 更新总结
-#             summary["brief"] = refined_data.get("brief", "主题未知的记忆")
+    #             # 更新总结
+    #             summary["brief"] = refined_data.get("brief", "主题未知的记忆")

-#             # 更新关键要点
-#             points = refined_data.get("points", [])
-#             if isinstance(points, list) and points:
-#                 # 确保所有要点都是字符串
-#                 summary["points"] = [str(point) for point in points if point is not None]
-#             else:
-#                 # 如果points不是列表或为空,使用默认值
-#                 summary["points"] = ["主要要点已遗忘"]
+    #             # 更新关键要点
+    #             points = refined_data.get("points", [])
+    #             if isinstance(points, list) and points:
+    #                 # 确保所有要点都是字符串
+    #                 summary["points"] = [str(point) for point in points if point is not None]
+    #             else:
+    #                 # 如果points不是列表或为空,使用默认值
+    #                 summary["points"] = ["主要要点已遗忘"]

-#         except Exception as e:
-#             logger.error(f"精简记忆出错: {str(e)}")
-#             traceback.print_exc()
+    #         except Exception as e:
+    #             logger.error(f"精简记忆出错: {str(e)}")
+    #             traceback.print_exc()

-#             # 出错时使用简化的默认精简
-#             summary["brief"] = summary["brief"] + " (已简化)"
-#             summary["points"] = summary.get("points", ["未知的要点"])[:1]
+    #             # 出错时使用简化的默认精简
+    #             summary["brief"] = summary["brief"] + " (已简化)"
+    #             summary["points"] = summary.get("points", ["未知的要点"])[:1]

-#     except Exception as e:
-#         logger.error(f"精简记忆调用LLM出错: {str(e)}")
-#         traceback.print_exc()
+    #     except Exception as e:
+    #         logger.error(f"精简记忆调用LLM出错: {str(e)}")
+    #         traceback.print_exc()

-#     # 更新原记忆项的总结
-#     memory_item.set_summary(summary)
+    #     # 更新原记忆项的总结
+    #     memory_item.set_summary(summary)

-#     return memory_item
+    #     return memory_item

     def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
         """
diff --git a/src/chat/focus_chat/working_memory/working_memory.py b/src/chat/focus_chat/working_memory/working_memory.py
index 6f3510709..496fdc37e 100644
--- a/src/chat/focus_chat/working_memory/working_memory.py
+++ b/src/chat/focus_chat/working_memory/working_memory.py
@@ -1,6 +1,5 @@
 from typing import List, Any, Optional
 import asyncio
-import random

 from src.common.logger_manager import get_logger
 from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index 43cc1fef6..300563658 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -24,7 +24,6 @@ from rich.traceback import install

 from ...config.config import global_config
 from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee Models导入
-from peewee import Case

 install(extra_lines=3)

@@ -217,7 +216,7 @@ class Hippocampus:
         """计算节点的特征值"""
         if not isinstance(memory_items, list):
             memory_items = [memory_items] if memory_items else []
-
+
         # 使用集合来去重,避免排序
         unique_items = set(str(item) for item in memory_items)
         # 使用frozenset来保证顺序一致性
@@ -816,7 +815,7 @@ class EntorhinalCortex:
         timestamps = sample_scheduler.get_timestamp_array()
         # 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
         readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
-        for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
+        for _, readable_timestamp in zip(timestamps, readable_timestamps):
             logger.debug(f"回忆往事: {readable_timestamp}")
         chat_samples = []
         for timestamp in timestamps:
@@ -843,16 +842,20 @@ class EntorhinalCortex:
         # 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
         timestamp_start = target_timestamp
         timestamp_end = target_timestamp + time_window_seconds
-
+
         chosen_message = get_raw_msg_by_timestamp(
             timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
         )
-
+
         if chosen_message:
             chat_id = chosen_message[0].get("chat_id")
             messages = get_raw_msg_by_timestamp_with_chat(
-                timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id
+                timestamp_start=timestamp_start,
+                timestamp_end=timestamp_end,
+                limit=chat_size,
+                limit_mode="earliest",
+                chat_id=chat_id,
             )

             if messages:
@@ -885,7 +888,7 @@ class EntorhinalCortex:
     async def sync_memory_to_db(self):
         """将记忆图同步到数据库"""
         start_time = time.time()
-
+
         # 获取数据库中所有节点和内存中所有节点
         db_load_start = time.time()
         db_nodes = {node.concept: node for node in GraphNodes.select()}
@@ -943,13 +946,15 @@ class EntorhinalCortex:

             if concept not in db_nodes:
                 # 数据库中缺少的节点,添加到创建列表
-                nodes_to_create.append({
-                    'concept': concept,
-                    'memory_items': memory_items_json,
-                    'hash': memory_hash,
-                    'created_time': created_time,
-                    'last_modified': last_modified
-                })
+                nodes_to_create.append(
+                    {
+                        "concept": concept,
+                        "memory_items": memory_items_json,
+                        "hash": memory_hash,
+                        "created_time": created_time,
+                        "last_modified": last_modified,
+                    }
+                )
                 logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
             else:
                 # 获取数据库中节点的特征值
@@ -958,12 +963,14 @@ class EntorhinalCortex:

                 # 如果特征值不同,则添加到更新列表
                 if db_hash != memory_hash:
-                    nodes_to_update.append({
-                        'concept': concept,
-                        'memory_items': memory_items_json,
-                        'hash': memory_hash,
-                        'last_modified': last_modified
-                    })
+                    nodes_to_update.append(
+                        {
+                            "concept": concept,
+                            "memory_items": memory_items_json,
+                            "hash": memory_hash,
+                            "last_modified": last_modified,
+                        }
+                    )

         # 检查需要删除的节点
         memory_concepts = {concept for concept, _ in memory_nodes}
@@ -972,7 +979,9 @@ class EntorhinalCortex:

         node_process_end = time.time()
         logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}秒")
-        logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
+        logger.info(
+            f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点"
+        )

         # 异步批量创建新节点
         node_create_start = time.time()
@@ -981,22 +990,24 @@ class EntorhinalCortex:
                 # 验证所有要创建的节点数据
                 valid_nodes_to_create = []
                 for node_data in nodes_to_create:
-                    if not node_data.get('memory_items'):
+                    if not node_data.get("memory_items"):
                         logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
                         continue
                     try:
                         # 验证 JSON 字符串
-                        json.loads(node_data['memory_items'])
+                        json.loads(node_data["memory_items"])
                         valid_nodes_to_create.append(node_data)
                     except json.JSONDecodeError:
-                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                        logger.warning(
+                            f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                        )
                         continue

                 if valid_nodes_to_create:
                     # 使用异步批量插入
                     batch_size = 100
                     for i in range(0, len(valid_nodes_to_create), batch_size):
-                        batch = valid_nodes_to_create[i:i + batch_size]
+                        batch = valid_nodes_to_create[i : i + batch_size]
                         await self._async_batch_create_nodes(batch)
                     logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
                 else:
@@ -1006,21 +1017,25 @@ class EntorhinalCortex:
                 # 尝试逐个创建以找出问题节点
                 for node_data in nodes_to_create:
                     try:
-                        if not node_data.get('memory_items'):
+                        if not node_data.get("memory_items"):
                             logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
                             continue
                         try:
-                            json.loads(node_data['memory_items'])
+                            json.loads(node_data["memory_items"])
                         except json.JSONDecodeError:
-                            logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            logger.warning(
+                                f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                            )
                             continue
                         await self._async_create_node(node_data)
                     except Exception as e:
                         logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
                         # 从图中移除问题节点
-                        self.memory_graph.G.remove_node(node_data['concept'])
+                        self.memory_graph.G.remove_node(node_data["concept"])
         node_create_end = time.time()
-        logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
+        logger.info(
+            f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)"
+        )

         # 异步批量更新节点
         node_update_start = time.time()
@@ -1028,30 +1043,32 @@ class EntorhinalCortex:
             # 按批次更新节点,每批100个
             batch_size = 100
             for i in range(0, len(nodes_to_update), batch_size):
-                batch = nodes_to_update[i:i + batch_size]
+                batch = nodes_to_update[i : i + batch_size]
                 try:
                     # 验证批次中的每个节点数据
                     valid_batch = []
                     for node_data in batch:
                         # 确保 memory_items 不为空且是有效的 JSON 字符串
-                        if not node_data.get('memory_items'):
+                        if not node_data.get("memory_items"):
                             logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
                             continue
                         try:
                             # 验证 JSON 字符串是否有效
-                            json.loads(node_data['memory_items'])
+                            json.loads(node_data["memory_items"])
                             valid_batch.append(node_data)
                         except json.JSONDecodeError:
-                            logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                            logger.warning(
+                                f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                            )
                             continue
-
+
                     if not valid_batch:
-                        logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
+                        logger.warning(f"[同步] 批次 {i // batch_size + 1} 没有有效的节点可以更新")
                         continue

                     # 异步批量更新节点
                     await self._async_batch_update_nodes(valid_batch)
-                    logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
+                    logger.debug(f"[同步] 成功更新批次 {i // batch_size + 1} 中的 {len(valid_batch)} 个节点")
                 except Exception as e:
                     logger.error(f"[同步] 批量更新节点失败: {e}")
                     # 如果批量更新失败,尝试逐个更新
@@ -1061,17 +1078,21 @@ class EntorhinalCortex:
                         except Exception as e:
                             logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
                             # 从图中移除问题节点
-                            self.memory_graph.G.remove_node(node_data['concept'])
+                            self.memory_graph.G.remove_node(node_data["concept"])
         node_update_end = time.time()
-        logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
+        logger.info(
+            f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)"
+        )

         # 异步删除不存在的节点
         node_delete_start = time.time()
         if nodes_to_delete:
             await self._async_delete_nodes(nodes_to_delete)
         node_delete_end = time.time()
-        logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
+        logger.info(
+            f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)"
+        )

         # 处理边的信息
         edge_load_start = time.time()
@@ -1106,24 +1127,28 @@ class EntorhinalCortex:

             if edge_key not in db_edge_dict:
                 # 添加新边到创建列表
-                edges_to_create.append({
-                    'source': source,
-                    'target': target,
-                    'strength': strength,
-                    'hash': edge_hash,
-                    'created_time': created_time,
-                    'last_modified': last_modified
-                })
+                edges_to_create.append(
+                    {
+                        "source": source,
+                        "target": target,
+                        "strength": strength,
+                        "hash": edge_hash,
+                        "created_time": created_time,
+                        "last_modified": last_modified,
+                    }
+                )
             else:
                 # 检查边的特征值是否变化
                 if db_edge_dict[edge_key]["hash"] != edge_hash:
-                    edges_to_update.append({
-                        'source': source,
-                        'target': target,
-                        'strength': strength,
-                        'hash': edge_hash,
-                        'last_modified': last_modified
-                    })
+                    edges_to_update.append(
+                        {
+                            "source": source,
+                            "target": target,
+                            "strength": strength,
+                            "hash": edge_hash,
+                            "last_modified": last_modified,
+                        }
+                    )

         edge_process_end = time.time()
         logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}秒")
@@ -1132,20 +1157,24 @@ class EntorhinalCortex:
         if edges_to_create:
             batch_size = 100
             for i in range(0, len(edges_to_create), batch_size):
-                batch = edges_to_create[i:i + batch_size]
+                batch = edges_to_create[i : i + batch_size]
                 await self._async_batch_create_edges(batch)
         edge_create_end = time.time()
-        logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
+        logger.info(
+            f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)"
+        )

         # 异步批量更新边
         edge_update_start = time.time()
         if edges_to_update:
             batch_size = 100
             for i in range(0, len(edges_to_update), batch_size):
-                batch = edges_to_update[i:i + batch_size]
+                batch = edges_to_update[i : i + batch_size]
                 await self._async_batch_update_edges(batch)
         edge_update_end = time.time()
-        logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
+        logger.info(
+            f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)"
+        )

         # 检查需要删除的边
         memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
@@ -1157,7 +1186,9 @@ class EntorhinalCortex:
         if edges_to_delete:
             await self._async_delete_edges(edges_to_delete)
         edge_delete_end = time.time()
-        logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
+        logger.info(
+            f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)"
+        )

         end_time = time.time()
         logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
@@ -1183,8 +1214,8 @@ class EntorhinalCortex:
         """异步批量更新节点"""
         try:
             for node_data in nodes_data:
-                GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
-                    GraphNodes.concept == node_data['concept']
+                GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
+                    GraphNodes.concept == node_data["concept"]
                 ).execute()
         except Exception as e:
             logger.error(f"[同步] 批量更新节点失败: {e}")
@@ -1193,8 +1224,8 @@ class EntorhinalCortex:
     async def _async_update_node(self, node_data):
         """异步更新单个节点"""
         try:
-            GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
-                GraphNodes.concept == node_data['concept']
+            GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
+                GraphNodes.concept == node_data["concept"]
             ).execute()
         except Exception as e:
             logger.error(f"[同步] 更新节点失败: {e}")
@@ -1220,9 +1251,8 @@ class EntorhinalCortex:
         """异步批量更新边"""
         try:
             for edge_data in edges_data:
-                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
-                    (GraphEdges.source == edge_data['source']) &
-                    (GraphEdges.target == edge_data['target'])
+                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
+                    (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
                 ).execute()
         except Exception as e:
             logger.error(f"[同步] 批量更新边失败: {e}")
@@ -1232,10 +1262,7 @@ class EntorhinalCortex:
         """异步删除边"""
         try:
             for source, target in edge_keys:
-                GraphEdges.delete().where(
-                    (GraphEdges.source == source) &
-                    (GraphEdges.target == target)
-                ).execute()
+                GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
         except Exception as e:
             logger.error(f"[同步] 删除边失败: {e}")
             raise
@@ -1406,7 +1433,7 @@ class ParahippocampalGyrus:
         if not input_text:
             logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
             return set(), {}
-
+
         current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
         current_YMD_time_str = f"当前日期: {current_YMD_time}"
         input_text = f"{current_YMD_time_str}\n{input_text}"
@@ -1533,8 +1560,7 @@ class ParahippocampalGyrus:
                     logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
                     all_added_edges.append(f"{topic1}-{topic2}")
                     self.memory_graph.connect_dot(topic1, topic2)
-
-
+
             progress = (i / len(memory_samples)) * 100
             bar_length = 30
             filled_length = int(bar_length * i // len(memory_samples))
diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py
index cb13721c1..8cb3b7f5c 100644
--- a/src/chat/normal_chat/normal_chat.py
+++ b/src/chat/normal_chat/normal_chat.py
@@ -79,7 +79,7 @@ class NormalChat:

         # 初始化Normal Chat专用表达器
         self.expressor = NormalChatExpressor(self.chat_stream, self.stream_name)
         self.replyer = DefaultReplyer(chat_id=self.stream_id)
-
+
         self.replyer.chat_stream = self.chat_stream

         self._initialized = True
@@ -243,9 +243,7 @@ class NormalChat:
                 self.interest_dict.pop(msg_id, None)

     # 改为实例方法, 移除 chat 参数
-    async def normal_response(
-        self, message: MessageRecv, is_mentioned: bool, interested_rate: float
-    ) -> None:
+    async def normal_response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
         # 新增:如果已停用,直接返回
         if self._disabled:
             logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。")
@@ -287,7 +285,6 @@ class NormalChat:

             # 回复前处理
             await willing_manager.before_generate_reply_handle(message.message_info.message_id)
-
             thinking_id = await self._create_thinking_message(message)

             logger.debug(f"[{self.stream_name}] 创建捕捉器,thinking_id:{thinking_id}")
@@ -362,6 +359,9 @@ class NormalChat:
                 if action_type == "no_action":
                     logger.debug(f"[{self.stream_name}] Planner决定不执行任何额外动作")
                     return None
+                elif action_type == "change_to_focus_chat":
+                    logger.info(f"[{self.stream_name}] Planner决定切换到focus聊天模式")
+                    return None

                 # 执行额外的动作(不影响回复生成)
                 action_result = await self._execute_action(action_type, action_data, message, thinking_id)
@@ -396,7 +396,9 @@ class NormalChat:
             elif plan_result:
                 logger.debug(f"[{self.stream_name}] 额外动作处理完成: {plan_result['action_type']}")

-            if not response_set or (self.enable_planner and self.action_type != "no_action"):
+            if not response_set or (
+                self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"]
+            ):
                 logger.info(f"[{self.stream_name}] 模型未生成回复内容")
                 # 如果模型未生成回复,移除思考消息
                 container = await message_manager.get_container(self.stream_id)  # 使用 self.stream_id
@@ -445,7 +447,15 @@ class NormalChat:

             # 检查是否需要切换到focus模式
             if global_config.chat.chat_mode == "auto":
-                await self._check_switch_to_focus()
+                if self.action_type == "change_to_focus_chat":
+                    logger.info(f"[{self.stream_name}] 检测到切换到focus聊天模式的请求")
+                    if self.on_switch_to_focus_callback:
+                        await self.on_switch_to_focus_callback()
+                    else:
+                        logger.warning(f"[{self.stream_name}] 没有设置切换到focus聊天模式的回调函数,无法执行切换")
+                    return
+                else:
+                    await self._check_switch_to_focus()

             info_catcher.done_catch()
diff --git a/src/chat/normal_chat/normal_chat_planner.py b/src/chat/normal_chat/normal_chat_planner.py
index 9c8a70221..4fd6978b9 100644
--- a/src/chat/normal_chat/normal_chat_planner.py
+++ b/src/chat/normal_chat/normal_chat_planner.py
@@ -28,6 +28,7 @@ def init_prompt():

 重要说明:
 - "no_action" 表示只进行普通聊天回复,不执行任何额外动作
+- "change_to_focus_chat" 表示当聊天变得热烈、自己回复条数很多或需要深入交流时,正常回复消息并切换到focus_chat模式进行更深入的对话
 - 其他action表示在普通回复的基础上,执行相应的额外动作

 你必须从上面列出的可用action中选择一个,并说明原因。
@@ -156,8 +157,8 @@ class NormalChatPlanner:
             # 提取其他参数作为action_data
             action_data = {k: v for k, v in action_result.items() if k not in ["action", "reasoning"]}

-            # 验证动作是否在可用动作列表中
-            if action not in current_available_actions:
+            # 验证动作是否在可用动作列表中,或者是特殊动作
+            if action not in current_available_actions and action != "change_to_focus_chat":
                 logger.warning(f"{self.log_prefix}规划器选择了不可用的动作: {action}, 回退到no_action")
                 action = "no_action"
                 reasoning = f"选择的动作{action}不在可用列表中,回退到no_action"
@@ -211,6 +212,19 @@ class NormalChatPlanner:
         try:
             # 构建动作选项文本
             action_options_text = ""
+
+            # 添加特殊的change_to_focus_chat动作
+            action_options_text += "action_name: change_to_focus_chat\n"
+            action_options_text += (
+                " 描述:当聊天变得热烈、自己回复条数很多或需要深入交流时使用,正常回复消息并切换到focus_chat模式\n"
+            )
+            action_options_text += " 参数:\n"
+            action_options_text += " 动作要求:\n"
+            action_options_text += " - 聊天上下文中自己的回复条数较多(超过3-4条)\n"
+            action_options_text += " - 对话进行得非常热烈活跃\n"
+            action_options_text += " - 用户表现出深入交流的意图\n"
+            action_options_text += " - 话题需要更专注和深入的讨论\n\n"
+
             for action_name, action_info in current_available_actions.items():
                 action_description = action_info.get("description", "")
                 action_parameters = action_info.get("parameters", {})
diff --git a/src/common/remote.py b/src/common/remote.py
index b61a43d8c..49c314f8d 100644
--- a/src/common/remote.py
+++ b/src/common/remote.py
@@ -142,7 +142,9 @@ class TelemetryHeartBeatTask(AsyncTask):
                     del local_storage["mmc_uuid"]  # 删除本地存储的UUID
                 else:
                     # 其他错误
-                    logger.warning(f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}")
+                    logger.warning(
+                        f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}"
+                    )

     async def run(self):
         # 发送心跳
diff --git a/src/individuality/expression_style.py b/src/individuality/expression_style.py
index 0d650ce46..77438d330 100644
--- a/src/individuality/expression_style.py
+++ b/src/individuality/expression_style.py
@@ -104,7 +104,13 @@ class PersonalityExpression:
         if count >= self.max_calculations:
             logger.debug(f"对于风格 '{current_style_text}' 已达到最大计算次数 ({self.max_calculations})。跳过提取。")
             # 即使跳过,也更新元数据以反映当前风格已被识别且计数已满
-            self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
+            self._write_meta_data(
+                {
+                    "last_style_text": current_style_text,
+                    "count": count,
+                    "last_update_time": meta_data.get("last_update_time"),
+                }
+            )
             return

         # 构建prompt
@@ -119,12 +125,18 @@ class PersonalityExpression:
         except Exception as e:
             logger.error(f"个性表达方式提取失败: {e}")
             # 如果提取失败,保存当前的风格和未增加的计数
-            self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
+            self._write_meta_data(
+                {
+                    "last_style_text": current_style_text,
+                    "count": count,
+                    "last_update_time": meta_data.get("last_update_time"),
+                }
+            )
             return

         logger.info(f"个性表达方式提取response: {response}")
         # chat_id用personality
-
+
         # 转为dict并count=100
         if response != "":
             expressions = self.parse_expression_response(response, "personality")
@@ -136,25 +148,27 @@ class PersonalityExpression:
                         existing_expressions = json.load(f)
                 except (json.JSONDecodeError, FileNotFoundError):
                     logger.warning(f"无法读取或解析 {self.expressions_file_path},将创建新的表达文件。")
-
+
             # 创建新的表达方式
             new_expressions = []
             for _, situation, style in expressions:
                 new_expressions.append({"situation": situation, "style": style, "count": 1})
-
+
             # 合并表达方式,如果situation和style相同则累加count
             merged_expressions = existing_expressions.copy()
             for new_expr in new_expressions:
                 found = False
                 for existing_expr in merged_expressions:
-                    if (existing_expr["situation"] == new_expr["situation"] and
-                        existing_expr["style"] == new_expr["style"]):
+                    if (
+                        existing_expr["situation"] == new_expr["situation"]
+                        and existing_expr["style"] == new_expr["style"]
+                    ):
                         existing_expr["count"] += new_expr["count"]
                         found = True
                         break
                 if not found:
                     merged_expressions.append(new_expr)
-
+
             # 超过50条时随机删除多余的,只保留50条
             if len(merged_expressions) > 50:
                 remove_count = len(merged_expressions) - 50
@@ -168,11 +182,9 @@ class PersonalityExpression:
             # 成功提取后更新元数据
             count += 1
             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-            self._write_meta_data({
-                "last_style_text": current_style_text,
-                "count": count,
-                "last_update_time": current_time
-            })
+            self._write_meta_data(
+                {"last_style_text": current_style_text, "count": count, "last_update_time": current_time}
+            )
             logger.info(f"成功处理。风格 '{current_style_text}' 的计数现在是 {count},最后更新时间:{current_time}。")
         else:
             logger.warning(f"个性表达方式提取失败,模型返回空内容: {response}")