diff --git a/scripts/analyze_group_similarity.py b/scripts/analyze_group_similarity.py new file mode 100644 index 000000000..86863a52c --- /dev/null +++ b/scripts/analyze_group_similarity.py @@ -0,0 +1,180 @@ +import json +import os +from pathlib import Path +import numpy as np +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +import matplotlib.pyplot as plt +import seaborn as sns +import networkx as nx +import matplotlib as mpl +import sqlite3 + +# 设置中文字体 +plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 使用微软雅黑 +plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 +plt.rcParams['font.family'] = 'sans-serif' + +# 获取脚本所在目录 +SCRIPT_DIR = Path(__file__).parent + +def get_group_name(stream_id): + """从数据库中获取群组名称""" + conn = sqlite3.connect('data/maibot.db') + cursor = conn.cursor() + + cursor.execute(''' + SELECT group_name, user_nickname, platform + FROM chat_streams + WHERE stream_id = ? + ''', (stream_id,)) + + result = cursor.fetchone() + conn.close() + + if result: + group_name, user_nickname, platform = result + if group_name: + return group_name + if user_nickname: + return user_nickname + if platform: + return f"{platform}-{stream_id[:8]}" + return stream_id + +def load_group_data(group_dir): + """加载单个群组的数据""" + json_path = Path(group_dir) / "expressions.json" + if not json_path.exists(): + return [], [], [] + + with open(json_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + situations = [] + styles = [] + combined = [] + + for item in data: + count = item['count'] + situations.extend([item['situation']] * count) + styles.extend([item['style']] * count) + combined.extend([f"{item['situation']} {item['style']}"] * count) + + return situations, styles, combined + +def analyze_group_similarity(): + # 获取所有群组目录 + base_dir = Path("data/expression/learnt_style") + group_dirs = [d for d in base_dir.iterdir() if d.is_dir()] + group_ids = [d.name for d in group_dirs] + + # 获取群组名称 + group_names = [get_group_name(group_id) for group_id in group_ids] + + # 加载所有群组的数据 + group_situations = [] + group_styles = [] + group_combined = [] + + for d in group_dirs: + situations, styles, combined = load_group_data(d) + group_situations.append(' '.join(situations)) + group_styles.append(' '.join(styles)) + group_combined.append(' '.join(combined)) + + # 创建TF-IDF向量化器 + vectorizer = TfidfVectorizer() + + # 计算三种相似度矩阵 + situation_matrix = cosine_similarity(vectorizer.fit_transform(group_situations)) + style_matrix = cosine_similarity(vectorizer.fit_transform(group_styles)) + combined_matrix = cosine_similarity(vectorizer.fit_transform(group_combined)) + + # 对相似度矩阵进行对数变换 + log_situation_matrix = np.log1p(situation_matrix) + log_style_matrix = np.log1p(style_matrix) + log_combined_matrix = np.log1p(combined_matrix) + + # 创建一个大图,包含三个子图 + plt.figure(figsize=(45, 12)) + + # 场景相似度热力图 + plt.subplot(1, 3, 1) + sns.heatmap(log_situation_matrix, + xticklabels=group_names, + yticklabels=group_names, + cmap='YlOrRd', + annot=True, + fmt='.2f', + vmin=0, + vmax=np.log1p(0.2)) + plt.title('群组场景相似度热力图 (对数变换)') + plt.xticks(rotation=45, ha='right') + + # 表达方式相似度热力图 + plt.subplot(1, 3, 2) + sns.heatmap(log_style_matrix, + xticklabels=group_names, + yticklabels=group_names, + cmap='YlOrRd', + annot=True, + fmt='.2f', + vmin=0, + vmax=np.log1p(0.2)) + plt.title('群组表达方式相似度热力图 (对数变换)') + plt.xticks(rotation=45, ha='right') + + # 组合相似度热力图 + plt.subplot(1, 3, 3) + sns.heatmap(log_combined_matrix, + xticklabels=group_names, + yticklabels=group_names, + 
cmap='YlOrRd', + annot=True, + fmt='.2f', + vmin=0, + vmax=np.log1p(0.2)) + plt.title('群组场景+表达方式相似度热力图 (对数变换)') + plt.xticks(rotation=45, ha='right') + + plt.tight_layout() + plt.savefig(SCRIPT_DIR / 'group_similarity_heatmaps.png', dpi=300, bbox_inches='tight') + plt.close() + + # 保存匹配详情到文本文件 + with open(SCRIPT_DIR / 'group_similarity_details.txt', 'w', encoding='utf-8') as f: + f.write('群组相似度详情\n') + f.write('=' * 50 + '\n\n') + + for i in range(len(group_ids)): + for j in range(i+1, len(group_ids)): + if log_combined_matrix[i][j] > np.log1p(0.05): + f.write(f'群组1: {group_names[i]}\n') + f.write(f'群组2: {group_names[j]}\n') + f.write(f'场景相似度: {situation_matrix[i][j]:.4f}\n') + f.write(f'表达方式相似度: {style_matrix[i][j]:.4f}\n') + f.write(f'组合相似度: {combined_matrix[i][j]:.4f}\n') + + # 获取两个群组的数据 + situations1, styles1, _ = load_group_data(group_dirs[i]) + situations2, styles2, _ = load_group_data(group_dirs[j]) + + # 找出共同的场景 + common_situations = set(situations1) & set(situations2) + if common_situations: + f.write('\n共同场景:\n') + for situation in common_situations: + f.write(f'- {situation}\n') + + # 找出共同的表达方式 + common_styles = set(styles1) & set(styles2) + if common_styles: + f.write('\n共同表达方式:\n') + for style in common_styles: + f.write(f'- {style}\n') + + f.write('\n' + '-' * 50 + '\n\n') + +if __name__ == "__main__": + analyze_group_similarity() diff --git a/scripts/group_similarity_heatmap.png b/scripts/group_similarity_heatmap.png new file mode 100644 index 000000000..217b3a0a5 Binary files /dev/null and b/scripts/group_similarity_heatmap.png differ diff --git a/scripts/group_similarity_network.png b/scripts/group_similarity_network.png new file mode 100644 index 000000000..fdcc816e9 Binary files /dev/null and b/scripts/group_similarity_network.png differ diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py index c6d2950fd..edd27e435 100644 --- a/scripts/mongodb_to_sqlite.py +++ b/scripts/mongodb_to_sqlite.py @@ -182,25 +182,6 @@ class MongoToSQLiteMigrator: enable_validation=False, # 禁用数据验证 unique_fields=["stream_id"], ), - # LLM使用记录迁移配置 - MigrationConfig( - mongo_collection="llm_usage", - target_model=LLMUsage, - field_mapping={ - "model_name": "model_name", - "user_id": "user_id", - "request_type": "request_type", - "endpoint": "endpoint", - "prompt_tokens": "prompt_tokens", - "completion_tokens": "completion_tokens", - "total_tokens": "total_tokens", - "cost": "cost", - "status": "status", - "timestamp": "timestamp", - }, - enable_validation=True, # 禁用数据验证" - unique_fields=["user_id", "prompt_tokens", "completion_tokens", "total_tokens", "cost"], # 组合唯一性 - ), # 消息迁移配置 MigrationConfig( mongo_collection="messages", diff --git a/scripts/preview_expressions.py b/scripts/preview_expressions.py new file mode 100644 index 000000000..0eebfb442 --- /dev/null +++ b/scripts/preview_expressions.py @@ -0,0 +1,265 @@ +import tkinter as tk +from tkinter import ttk +import json +import os +from pathlib import Path +import networkx as nx +import matplotlib.pyplot as plt +from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +import numpy as np +from collections import defaultdict + +class ExpressionViewer: + def __init__(self, root): + self.root = root + self.root.title("表达方式预览器") + self.root.geometry("1200x800") + + # 创建主框架 + self.main_frame = ttk.Frame(root) + self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) + + # 创建左侧控制面板 + 
self.control_frame = ttk.Frame(self.main_frame) + self.control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10)) + + # 创建搜索框 + self.search_frame = ttk.Frame(self.control_frame) + self.search_frame.pack(fill=tk.X, pady=(0, 10)) + + self.search_var = tk.StringVar() + self.search_var.trace('w', self.filter_expressions) + self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var) + self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True) + ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5)) + + # 创建文件选择下拉框 + self.file_var = tk.StringVar() + self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var) + self.file_combo.pack(side=tk.LEFT, padx=5) + self.file_combo.bind('<>', self.load_file) + + # 创建排序选项 + self.sort_frame = ttk.LabelFrame(self.control_frame, text="排序选项") + self.sort_frame.pack(fill=tk.X, pady=5) + + self.sort_var = tk.StringVar(value="count") + ttk.Radiobutton(self.sort_frame, text="按计数排序", variable=self.sort_var, + value="count", command=self.apply_sort).pack(anchor=tk.W) + ttk.Radiobutton(self.sort_frame, text="按情境排序", variable=self.sort_var, + value="situation", command=self.apply_sort).pack(anchor=tk.W) + ttk.Radiobutton(self.sort_frame, text="按风格排序", variable=self.sort_var, + value="style", command=self.apply_sort).pack(anchor=tk.W) + + # 创建分群选项 + self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项") + self.group_frame.pack(fill=tk.X, pady=5) + + self.group_var = tk.StringVar(value="none") + ttk.Radiobutton(self.group_frame, text="不分群", variable=self.group_var, + value="none", command=self.apply_grouping).pack(anchor=tk.W) + ttk.Radiobutton(self.group_frame, text="按情境分群", variable=self.group_var, + value="situation", command=self.apply_grouping).pack(anchor=tk.W) + ttk.Radiobutton(self.group_frame, text="按风格分群", variable=self.group_var, + value="style", command=self.apply_grouping).pack(anchor=tk.W) + + # 创建相似度阈值滑块 + self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置") + self.similarity_frame.pack(fill=tk.X, pady=5) + + self.similarity_var = tk.DoubleVar(value=0.5) + self.similarity_scale = ttk.Scale(self.similarity_frame, from_=0.0, to=1.0, + variable=self.similarity_var, orient=tk.HORIZONTAL, + command=self.update_similarity) + self.similarity_scale.pack(fill=tk.X, padx=5, pady=5) + ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack() + + # 创建显示选项 + self.view_frame = ttk.LabelFrame(self.control_frame, text="显示选项") + self.view_frame.pack(fill=tk.X, pady=5) + + self.show_graph_var = tk.BooleanVar(value=True) + ttk.Checkbutton(self.view_frame, text="显示关系图", variable=self.show_graph_var, + command=self.toggle_graph).pack(anchor=tk.W) + + # 创建右侧内容区域 + self.content_frame = ttk.Frame(self.main_frame) + self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) + + # 创建文本显示区域 + self.text_area = tk.Text(self.content_frame, wrap=tk.WORD) + self.text_area.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + + # 添加滚动条 + scrollbar = ttk.Scrollbar(self.text_area, command=self.text_area.yview) + scrollbar.pack(side=tk.RIGHT, fill=tk.Y) + self.text_area.config(yscrollcommand=scrollbar.set) + + # 创建图形显示区域 + self.graph_frame = ttk.Frame(self.content_frame) + self.graph_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True) + + # 初始化数据 + self.current_data = [] + self.graph = nx.Graph() + self.canvas = None + + # 加载文件列表 + self.load_file_list() + + def load_file_list(self): + expression_dir = Path("data/expression") + files = [] + for root, _, filenames in os.walk(expression_dir): + for filename in 
filenames: + if filename.endswith('.json'): + rel_path = os.path.relpath(os.path.join(root, filename), expression_dir) + files.append(rel_path) + + self.file_combo['values'] = files + if files: + self.file_combo.set(files[0]) + self.load_file(None) + + def load_file(self, event): + selected_file = self.file_var.get() + if not selected_file: + return + + file_path = os.path.join("data/expression", selected_file) + try: + with open(file_path, 'r', encoding='utf-8') as f: + self.current_data = json.load(f) + + self.apply_sort() + self.update_similarity() + + except Exception as e: + self.text_area.delete(1.0, tk.END) + self.text_area.insert(tk.END, f"加载文件时出错: {str(e)}") + + def apply_sort(self): + if not self.current_data: + return + + sort_key = self.sort_var.get() + reverse = sort_key == "count" + + self.current_data.sort(key=lambda x: x.get(sort_key, ""), reverse=reverse) + self.apply_grouping() + + def apply_grouping(self): + if not self.current_data: + return + + group_key = self.group_var.get() + if group_key == "none": + self.display_data(self.current_data) + return + + grouped_data = defaultdict(list) + for item in self.current_data: + key = item.get(group_key, "未分类") + grouped_data[key].append(item) + + self.text_area.delete(1.0, tk.END) + for group, items in grouped_data.items(): + self.text_area.insert(tk.END, f"\n=== {group} ===\n\n") + for item in items: + self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n") + self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n") + self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n") + self.text_area.insert(tk.END, "-" * 50 + "\n") + + def display_data(self, data): + self.text_area.delete(1.0, tk.END) + for item in data: + self.text_area.insert(tk.END, f"情境: {item.get('situation', 'N/A')}\n") + self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n") + self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n") + self.text_area.insert(tk.END, "-" * 50 + "\n") + + def update_similarity(self, *args): + if not self.current_data: + return + + threshold = self.similarity_var.get() + self.similarity_frame.winfo_children()[-1].config(text=f"相似度阈值: {threshold:.2f}") + + # 计算相似度 + texts = [f"{item['situation']} {item['style']}" for item in self.current_data] + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform(texts) + similarity_matrix = cosine_similarity(tfidf_matrix) + + # 创建图 + self.graph.clear() + for i, item in enumerate(self.current_data): + self.graph.add_node(i, label=f"{item['situation']}\n{item['style']}") + + # 添加边 + for i in range(len(self.current_data)): + for j in range(i + 1, len(self.current_data)): + if similarity_matrix[i, j] > threshold: + self.graph.add_edge(i, j, weight=similarity_matrix[i, j]) + + if self.show_graph_var.get(): + self.draw_graph() + + def draw_graph(self): + if self.canvas: + self.canvas.get_tk_widget().destroy() + + fig = plt.figure(figsize=(8, 6)) + pos = nx.spring_layout(self.graph) + + # 绘制节点 + nx.draw_networkx_nodes(self.graph, pos, node_color='lightblue', + node_size=1000, alpha=0.6) + + # 绘制边 + nx.draw_networkx_edges(self.graph, pos, alpha=0.4) + + # 添加标签 + labels = nx.get_node_attributes(self.graph, 'label') + nx.draw_networkx_labels(self.graph, pos, labels, font_size=8) + + plt.title("表达方式关系图") + plt.axis('off') + + self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame) + self.canvas.draw() + self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True) + + def toggle_graph(self): + if self.show_graph_var.get(): + 
self.draw_graph() + else: + if self.canvas: + self.canvas.get_tk_widget().destroy() + self.canvas = None + + def filter_expressions(self, *args): + search_text = self.search_var.get().lower() + if not search_text: + self.apply_sort() + return + + filtered_data = [] + for item in self.current_data: + situation = item.get('situation', '').lower() + style = item.get('style', '').lower() + if search_text in situation or search_text in style: + filtered_data.append(item) + + self.display_data(filtered_data) + +def main(): + root = tk.Tk() + app = ExpressionViewer(root) + root.mainloop() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/chat/focus_chat/info_processors/working_memory_processor.py b/src/chat/focus_chat/info_processors/working_memory_processor.py index e13cddbe4..a7e6ae6cd 100644 --- a/src/chat/focus_chat/info_processors/working_memory_processor.py +++ b/src/chat/focus_chat/info_processors/working_memory_processor.py @@ -162,7 +162,7 @@ class WorkingMemoryProcessor(BaseProcessor): memory_brief = memory_summary.get("brief") memory_points = memory_summary.get("points", []) for point in memory_points: - memory_str += f"记忆要点:{point}\n" + memory_str += f"{point}\n" working_memory_info = WorkingMemoryInfo() if memory_str: diff --git a/src/chat/focus_chat/planners/action_manager.py b/src/chat/focus_chat/planners/action_manager.py index dc5343733..fc6f567e2 100644 --- a/src/chat/focus_chat/planners/action_manager.py +++ b/src/chat/focus_chat/planners/action_manager.py @@ -135,11 +135,11 @@ class ActionManager: cycle_timers: dict, thinking_id: str, observations: List[Observation], - expressor: DefaultExpressor, - replyer: DefaultReplyer, chat_stream: ChatStream, log_prefix: str, shutting_down: bool = False, + expressor: DefaultExpressor = None, + replyer: DefaultReplyer = None, ) -> Optional[BaseAction]: """ 创建动作处理器实例 diff --git a/src/chat/focus_chat/planners/actions/plugin_action.py b/src/chat/focus_chat/planners/actions/plugin_action.py index b450e9bb4..f68aa5834 100644 --- a/src/chat/focus_chat/planners/actions/plugin_action.py +++ b/src/chat/focus_chat/planners/actions/plugin_action.py @@ -182,26 +182,33 @@ class PluginAction(BaseAction): Returns: bool: 是否发送成功 """ - try: - expressor = self._services.get("expressor") - chat_stream = self._services.get("chat_stream") + expressor = self._services.get("expressor") + chat_stream = self._services.get("chat_stream") - if not expressor or not chat_stream: - logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") - return False + if not expressor or not chat_stream: + logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务") + return False - # 构造简化的动作数据 - reply_data = {"text": text, "target": target or "", "emojis": []} + # 构造简化的动作数据 + reply_data = {"text": text, "target": target or "", "emojis": []} - # 获取锚定消息(如果有) - observations = self._services.get("observations", []) + # 获取锚定消息(如果有) + observations = self._services.get("observations", []) + + # 查找 ChattingObservation 实例 + chatting_observation = None + for obs in observations: + if isinstance(obs, ChattingObservation): + chatting_observation = obs + break - chatting_observation: ChattingObservation = next( - obs for obs in observations if isinstance(obs, ChattingObservation) + if not chatting_observation: + logger.warning(f"{self.log_prefix} 未找到 ChattingObservation 实例,创建占位符") + anchor_message = await create_empty_anchor_message( + chat_stream.platform, chat_stream.group_info, chat_stream ) + else: anchor_message = 
chatting_observation.search_message_by_text(reply_data["target"]) - - # 如果没有找到锚点消息,创建一个占位符 if not anchor_message: logger.info(f"{self.log_prefix} 未找到锚点消息,创建占位符") anchor_message = await create_empty_anchor_message( @@ -210,19 +217,16 @@ class PluginAction(BaseAction): else: anchor_message.update_chat_stream(chat_stream) - # 调用内部方法发送消息 - success, _ = await expressor.deal_reply( - cycle_timers=self.cycle_timers, - action_data=reply_data, - anchor_message=anchor_message, - reasoning=self.reasoning, - thinking_id=self.thinking_id, - ) + # 调用内部方法发送消息 + success, _ = await expressor.deal_reply( + cycle_timers=self.cycle_timers, + action_data=reply_data, + anchor_message=anchor_message, + reasoning=self.reasoning, + thinking_id=self.thinking_id, + ) - return success - except Exception as e: - logger.error(f"{self.log_prefix} 发送消息时出错: {e}") - return False + return success def get_chat_type(self) -> str: """获取当前聊天类型 diff --git a/src/chat/focus_chat/working_memory/memory_manager.py b/src/chat/focus_chat/working_memory/memory_manager.py index 27f7ad0b1..21aa94707 100644 --- a/src/chat/focus_chat/working_memory/memory_manager.py +++ b/src/chat/focus_chat/working_memory/memory_manager.py @@ -286,110 +286,110 @@ class MemoryManager: logger.error(f"生成总结时出错: {str(e)}") return default_summary - async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]: - """ - 对记忆进行精简操作,根据要求修改要点、总结和概括 +# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]: +# """ +# 对记忆进行精简操作,根据要求修改要点、总结和概括 - Args: - memory_id: 记忆ID - requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点 +# Args: +# memory_id: 记忆ID +# requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点 - Returns: - 修改后的记忆总结字典 - """ - # 获取指定ID的记忆项 - logger.info(f"精简记忆: {memory_id}") - memory_item = self.get_by_id(memory_id) - if not memory_item: - raise ValueError(f"未找到ID为{memory_id}的记忆项") +# Returns: +# 修改后的记忆总结字典 +# """ +# # 获取指定ID的记忆项 +# logger.info(f"精简记忆: {memory_id}") +# memory_item = self.get_by_id(memory_id) +# if not memory_item: +# raise ValueError(f"未找到ID为{memory_id}的记忆项") - # 增加精简次数 - memory_item.increase_compress_count() +# # 增加精简次数 +# memory_item.increase_compress_count() - summary = memory_item.summary +# summary = memory_item.summary - # 使用LLM根据要求对总结、概括和要点进行精简修改 - prompt = f""" -请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程: -要求:{requirements} -你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题 +# # 使用LLM根据要求对总结、概括和要点进行精简修改 +# prompt = f""" +# 请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程: +# 要求:{requirements} +# 你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题 -目前主题:{summary["brief"]} +# 目前主题:{summary["brief"]} -目前关键要点: -{chr(10).join([f"- {point}" for point in summary.get("points", [])])} +# 目前关键要点: +# {chr(10).join([f"- {point}" for point in summary.get("points", [])])} -请生成修改后的主题和关键要点,遵循以下格式: -```json -{{ - "brief": "修改后的主题(20字以内)", - "points": [ - "修改后的要点", - "修改后的要点" - ] -}} -``` -请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 -""" - # 定义默认的精简结果 - default_refined = { - "brief": summary["brief"], - "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点 - } +# 请生成修改后的主题和关键要点,遵循以下格式: +# ```json +# {{ +# "brief": "修改后的主题(20字以内)", +# "points": [ +# "修改后的要点", +# "修改后的要点" +# ] +# }} +# ``` +# 请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。 +# """ +# # 定义默认的精简结果 +# default_refined = { +# "brief": summary["brief"], +# "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点 +# } - try: - # 调用LLM修改总结、概括和要点 - response, _ = await self.llm_summarizer.generate_response_async(prompt) - logger.debug(f"精简记忆响应: {response}") - # 使用repair_json处理响应 - try: - # 修复JSON格式 - fixed_json_string = 
repair_json(response) +# try: +# # 调用LLM修改总结、概括和要点 +# response, _ = await self.llm_summarizer.generate_response_async(prompt) +# logger.debug(f"精简记忆响应: {response}") +# # 使用repair_json处理响应 +# try: +# # 修复JSON格式 +# fixed_json_string = repair_json(response) - # 将修复后的字符串解析为Python对象 - if isinstance(fixed_json_string, str): - try: - refined_data = json.loads(fixed_json_string) - except json.JSONDecodeError as decode_error: - logger.error(f"JSON解析错误: {str(decode_error)}") - refined_data = default_refined - else: - # 如果repair_json直接返回了字典对象,直接使用 - refined_data = fixed_json_string +# # 将修复后的字符串解析为Python对象 +# if isinstance(fixed_json_string, str): +# try: +# refined_data = json.loads(fixed_json_string) +# except json.JSONDecodeError as decode_error: +# logger.error(f"JSON解析错误: {str(decode_error)}") +# refined_data = default_refined +# else: +# # 如果repair_json直接返回了字典对象,直接使用 +# refined_data = fixed_json_string - # 确保是字典类型 - if not isinstance(refined_data, dict): - logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}") - refined_data = default_refined +# # 确保是字典类型 +# if not isinstance(refined_data, dict): +# logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}") +# refined_data = default_refined - # 更新总结 - summary["brief"] = refined_data.get("brief", "主题未知的记忆") +# # 更新总结 +# summary["brief"] = refined_data.get("brief", "主题未知的记忆") - # 更新关键要点 - points = refined_data.get("points", []) - if isinstance(points, list) and points: - # 确保所有要点都是字符串 - summary["points"] = [str(point) for point in points if point is not None] - else: - # 如果points不是列表或为空,使用默认值 - summary["points"] = ["主要要点已遗忘"] +# # 更新关键要点 +# points = refined_data.get("points", []) +# if isinstance(points, list) and points: +# # 确保所有要点都是字符串 +# summary["points"] = [str(point) for point in points if point is not None] +# else: +# # 如果points不是列表或为空,使用默认值 +# summary["points"] = ["主要要点已遗忘"] - except Exception as e: - logger.error(f"精简记忆出错: {str(e)}") - traceback.print_exc() +# except Exception as e: +# logger.error(f"精简记忆出错: {str(e)}") +# traceback.print_exc() - # 出错时使用简化的默认精简 - summary["brief"] = summary["brief"] + " (已简化)" - summary["points"] = summary.get("points", ["未知的要点"])[:1] +# # 出错时使用简化的默认精简 +# summary["brief"] = summary["brief"] + " (已简化)" +# summary["points"] = summary.get("points", ["未知的要点"])[:1] - except Exception as e: - logger.error(f"精简记忆调用LLM出错: {str(e)}") - traceback.print_exc() +# except Exception as e: +# logger.error(f"精简记忆调用LLM出错: {str(e)}") +# traceback.print_exc() - # 更新原记忆项的总结 - memory_item.set_summary(summary) +# # 更新原记忆项的总结 +# memory_item.set_summary(summary) - return memory_item +# return memory_item def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool: """ diff --git a/src/chat/focus_chat/working_memory/working_memory.py b/src/chat/focus_chat/working_memory/working_memory.py index b06456a50..6f3510709 100644 --- a/src/chat/focus_chat/working_memory/working_memory.py +++ b/src/chat/focus_chat/working_memory/working_memory.py @@ -112,10 +112,10 @@ class WorkingMemory: self.memory_manager.delete(memory_id) continue # 计算衰减量 - if memory_item.memory_strength < 5: - await self.memory_manager.refine_memory( - memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩" - ) + # if memory_item.memory_strength < 5: + # await self.memory_manager.refine_memory( + # memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩" + # ) async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem: """合并记忆 @@ -127,51 +127,6 @@ class WorkingMemory: memory_id1=memory_id1, memory_id2=memory_id2, 
reason="两端记忆有重复的内容" ) - # 暂时没用,先留着 - async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2): - """ - 模拟记忆模糊过程,随机选择一部分记忆进行精简 - - Args: - chat_id: 聊天ID - blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简 - """ - memory = self.get_memory(chat_id) - - # 获取所有字符串类型且有总结的记忆 - all_summarized_memories = [] - for type_items in memory._memory.values(): - for item in type_items: - if isinstance(item.data, str) and hasattr(item, "summary") and item.summary: - all_summarized_memories.append(item) - - if not all_summarized_memories: - return - - # 计算要模糊的记忆数量 - blur_count = max(1, int(len(all_summarized_memories) * blur_rate)) - - # 随机选择要模糊的记忆 - memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories))) - - # 对选中的记忆进行精简 - for memory_item in memories_to_blur: - try: - # 根据记忆强度决定模糊程度 - if memory_item.memory_strength > 7: - requirement = "保留所有重要信息,仅略微精简" - elif memory_item.memory_strength > 4: - requirement = "保留核心要点,适度精简细节" - else: - requirement = "只保留最关键的1-2个要点,大幅精简内容" - - # 进行精简 - await memory.refine_memory(memory_item.id, requirement) - print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}") - - except Exception as e: - print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}") - async def shutdown(self) -> None: """关闭管理器,停止所有任务""" if self.decay_task and not self.decay_task.done(): diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index e63840f11..43cc1fef6 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -17,12 +17,14 @@ from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # from ..utils.chat_message_builder import ( get_raw_msg_by_timestamp, build_readable_messages, + get_raw_msg_by_timestamp_with_chat, ) # 导入 build_readable_messages from ..utils.utils import translate_timestamp_to_human_readable from rich.traceback import install from ...config.config import global_config from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入 +from peewee import Case install(extra_lines=3) @@ -215,15 +217,18 @@ class Hippocampus: """计算节点的特征值""" if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] - sorted_items = sorted(memory_items) - content = f"{concept}:{'|'.join(sorted_items)}" + + # 使用集合来去重,避免排序 + unique_items = set(str(item) for item in memory_items) + # 使用frozenset来保证顺序一致性 + content = f"{concept}:{frozenset(unique_items)}" return hash(content) @staticmethod def calculate_edge_hash(source, target) -> int: """计算边的特征值""" - nodes = sorted([source, target]) - return hash(f"{nodes[0]}:{nodes[1]}") + # 直接使用元组,保证顺序一致性 + return hash((source, target)) @staticmethod def find_topic_llm(text, topic_num): @@ -811,7 +816,8 @@ class EntorhinalCortex: timestamps = sample_scheduler.get_timestamp_array() # 使用 translate_timestamp_to_human_readable 并指定 mode="normal" readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps] - logger.info(f"回忆往事: {readable_timestamps}") + for timestamp, readable_timestamp in zip(timestamps, readable_timestamps): + logger.debug(f"回忆往事: {readable_timestamp}") chat_samples = [] for timestamp in timestamps: # 调用修改后的 random_get_msg_snippet @@ -820,10 +826,10 @@ class EntorhinalCortex: ) if messages: time_diff = (datetime.datetime.now().timestamp() - timestamp) / 3600 - logger.debug(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") + logger.success(f"成功抽取 {time_diff:.1f} 小时前的消息样本,共{len(messages)}条") 
chat_samples.append(messages) else: - logger.debug(f"时间戳 {timestamp} 的消息样本抽取失败") + logger.debug(f"时间戳 {timestamp} 的消息无需记忆") return chat_samples @@ -837,32 +843,37 @@ class EntorhinalCortex: # 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds timestamp_start = target_timestamp timestamp_end = target_timestamp + time_window_seconds - - # 使用 chat_message_builder 的函数获取消息 - # limit_mode='earliest' 获取这个时间窗口内最早的 chat_size 条消息 - messages = get_raw_msg_by_timestamp( - timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest" + + chosen_message = get_raw_msg_by_timestamp( + timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest" ) + + if chosen_message: + chat_id = chosen_message[0].get("chat_id") - if messages: - # 检查获取到的所有消息是否都未达到最大记忆次数 - all_valid = True - for message in messages: - if message.get("memorized_times", 0) >= max_memorized_time_per_msg: - all_valid = False - break + messages = get_raw_msg_by_timestamp_with_chat( + timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id + ) - # 如果所有消息都有效 - if all_valid: - # 更新数据库中的记忆次数 + if messages: + # 检查获取到的所有消息是否都未达到最大记忆次数 + all_valid = True for message in messages: - # 确保在更新前获取最新的 memorized_times - current_memorized_times = message.get("memorized_times", 0) - # 使用 Peewee 更新记录 - Messages.update(memorized_times=current_memorized_times + 1).where( - Messages.message_id == message["message_id"] - ).execute() - return messages # 直接返回原始的消息列表 + if message.get("memorized_times", 0) >= max_memorized_time_per_msg: + all_valid = False + break + + # 如果所有消息都有效 + if all_valid: + # 更新数据库中的记忆次数 + for message in messages: + # 确保在更新前获取最新的 memorized_times + current_memorized_times = message.get("memorized_times", 0) + # 使用 Peewee 更新记录 + Messages.update(memorized_times=current_memorized_times + 1).where( + Messages.message_id == message["message_id"] + ).execute() + return messages # 直接返回原始的消息列表 # 如果获取失败或消息无效,增加尝试次数 try_count += 1 @@ -873,85 +884,361 @@ class EntorhinalCortex: async def sync_memory_to_db(self): """将记忆图同步到数据库""" + start_time = time.time() + # 获取数据库中所有节点和内存中所有节点 + db_load_start = time.time() db_nodes = {node.concept: node for node in GraphNodes.select()} memory_nodes = list(self.memory_graph.G.nodes(data=True)) + db_load_end = time.time() + logger.info(f"[同步] 加载数据库耗时: {db_load_end - db_load_start:.2f}秒") + + # 批量准备节点数据 + nodes_to_create = [] + nodes_to_update = [] + current_time = datetime.datetime.now().timestamp() # 检查并更新节点 + node_process_start = time.time() for concept, data in memory_nodes: + # 检查概念是否有效 + if not concept or not isinstance(concept, str): + logger.warning(f"[同步] 发现无效概念,将移除节点: {concept}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue + memory_items = data.get("memory_items", []) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] + # 检查记忆项是否为空 + if not memory_items: + logger.warning(f"[同步] 发现空记忆节点,将移除节点: {concept}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue + # 计算内存中节点的特征值 memory_hash = self.hippocampus.calculate_node_hash(concept, memory_items) # 获取时间信息 - created_time = data.get("created_time", datetime.datetime.now().timestamp()) - last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) + created_time = data.get("created_time", current_time) + last_modified = data.get("last_modified", current_time) # 将memory_items转换为JSON字符串 - memory_items_json = json.dumps(memory_items, 
ensure_ascii=False) + try: + # 确保memory_items中的每个项都是字符串 + memory_items = [str(item) for item in memory_items] + memory_items_json = json.dumps(memory_items, ensure_ascii=False) + if not memory_items_json: # 确保JSON字符串不为空 + raise ValueError("序列化后的JSON字符串为空") + # 验证JSON字符串是否有效 + json.loads(memory_items_json) + except Exception as e: + logger.error(f"[同步] 序列化记忆项失败,将移除节点: {concept}, 错误: {e}") + # 从图中移除节点(这会自动移除相关的边) + self.memory_graph.G.remove_node(concept) + continue if concept not in db_nodes: - # 数据库中缺少的节点,添加 - GraphNodes.create( - concept=concept, - memory_items=memory_items_json, - hash=memory_hash, - created_time=created_time, - last_modified=last_modified, - ) + # 数据库中缺少的节点,添加到创建列表 + nodes_to_create.append({ + 'concept': concept, + 'memory_items': memory_items_json, + 'hash': memory_hash, + 'created_time': created_time, + 'last_modified': last_modified + }) + logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}") else: # 获取数据库中节点的特征值 db_node = db_nodes[concept] db_hash = db_node.hash - # 如果特征值不同,则更新节点 + # 如果特征值不同,则添加到更新列表 if db_hash != memory_hash: - db_node.memory_items = memory_items_json - db_node.hash = memory_hash - db_node.last_modified = last_modified - db_node.save() + nodes_to_update.append({ + 'concept': concept, + 'memory_items': memory_items_json, + 'hash': memory_hash, + 'last_modified': last_modified + }) + + # 检查需要删除的节点 + memory_concepts = {concept for concept, _ in memory_nodes} + db_concepts = set(db_nodes.keys()) + nodes_to_delete = db_concepts - memory_concepts + + node_process_end = time.time() + logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}秒") + logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点") + + # 异步批量创建新节点 + node_create_start = time.time() + if nodes_to_create: + try: + # 验证所有要创建的节点数据 + valid_nodes_to_create = [] + for node_data in nodes_to_create: + if not node_data.get('memory_items'): + logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空") + continue + try: + # 验证 JSON 字符串 + json.loads(node_data['memory_items']) + valid_nodes_to_create.append(node_data) + except json.JSONDecodeError: + logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串") + continue + + if valid_nodes_to_create: + # 使用异步批量插入 + batch_size = 100 + for i in range(0, len(valid_nodes_to_create), batch_size): + batch = valid_nodes_to_create[i:i + batch_size] + await self._async_batch_create_nodes(batch) + logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点") + else: + logger.warning("[同步] 没有有效的节点可以创建") + except Exception as e: + logger.error(f"[同步] 创建节点失败: {e}") + # 尝试逐个创建以找出问题节点 + for node_data in nodes_to_create: + try: + if not node_data.get('memory_items'): + logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空") + continue + try: + json.loads(node_data['memory_items']) + except json.JSONDecodeError: + logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串") + continue + await self._async_create_node(node_data) + except Exception as e: + logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}") + # 从图中移除问题节点 + self.memory_graph.G.remove_node(node_data['concept']) + node_create_end = time.time() + logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)") + + # 异步批量更新节点 + node_update_start = time.time() + if nodes_to_update: + # 按批次更新节点,每批100个 + batch_size = 100 + for i in range(0, len(nodes_to_update), batch_size): + batch = 
nodes_to_update[i:i + batch_size] + try: + # 验证批次中的每个节点数据 + valid_batch = [] + for node_data in batch: + # 确保 memory_items 不为空且是有效的 JSON 字符串 + if not node_data.get('memory_items'): + logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空") + continue + try: + # 验证 JSON 字符串是否有效 + json.loads(node_data['memory_items']) + valid_batch.append(node_data) + except json.JSONDecodeError: + logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串") + continue + + if not valid_batch: + logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新") + continue + + # 异步批量更新节点 + await self._async_batch_update_nodes(valid_batch) + logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点") + except Exception as e: + logger.error(f"[同步] 批量更新节点失败: {e}") + # 如果批量更新失败,尝试逐个更新 + for node_data in valid_batch: + try: + await self._async_update_node(node_data) + except Exception as e: + logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}") + # 从图中移除问题节点 + self.memory_graph.G.remove_node(node_data['concept']) + + node_update_end = time.time() + logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)") + + # 异步删除不存在的节点 + node_delete_start = time.time() + if nodes_to_delete: + await self._async_delete_nodes(nodes_to_delete) + node_delete_end = time.time() + logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)") # 处理边的信息 + edge_load_start = time.time() db_edges = list(GraphEdges.select()) memory_edges = list(self.memory_graph.G.edges(data=True)) + edge_load_end = time.time() + logger.info(f"[同步] 加载边数据耗时: {edge_load_end - edge_load_start:.2f}秒") # 创建边的哈希值字典 + edge_dict_start = time.time() db_edge_dict = {} for edge in db_edges: edge_hash = self.hippocampus.calculate_edge_hash(edge.source, edge.target) db_edge_dict[(edge.source, edge.target)] = {"hash": edge_hash, "strength": edge.strength} + edge_dict_end = time.time() + logger.info(f"[同步] 创建边字典耗时: {edge_dict_end - edge_dict_start:.2f}秒") + + # 批量准备边数据 + edges_to_create = [] + edges_to_update = [] # 检查并更新边 + edge_process_start = time.time() for source, target, data in memory_edges: edge_hash = self.hippocampus.calculate_edge_hash(source, target) edge_key = (source, target) strength = data.get("strength", 1) # 获取边的时间信息 - created_time = data.get("created_time", datetime.datetime.now().timestamp()) - last_modified = data.get("last_modified", datetime.datetime.now().timestamp()) + created_time = data.get("created_time", current_time) + last_modified = data.get("last_modified", current_time) if edge_key not in db_edge_dict: - # 添加新边 - GraphEdges.create( - source=source, - target=target, - strength=strength, - hash=edge_hash, - created_time=created_time, - last_modified=last_modified, - ) + # 添加新边到创建列表 + edges_to_create.append({ + 'source': source, + 'target': target, + 'strength': strength, + 'hash': edge_hash, + 'created_time': created_time, + 'last_modified': last_modified + }) else: # 检查边的特征值是否变化 if db_edge_dict[edge_key]["hash"] != edge_hash: - edge = GraphEdges.get(GraphEdges.source == source, GraphEdges.target == target) - edge.hash = edge_hash - edge.strength = strength - edge.last_modified = last_modified - edge.save() + edges_to_update.append({ + 'source': source, + 'target': target, + 'strength': strength, + 'hash': edge_hash, + 'last_modified': last_modified + }) + edge_process_end = time.time() + logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}秒") + + # 异步批量创建新边 + edge_create_start = 
time.time() + if edges_to_create: + batch_size = 100 + for i in range(0, len(edges_to_create), batch_size): + batch = edges_to_create[i:i + batch_size] + await self._async_batch_create_edges(batch) + edge_create_end = time.time() + logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)") + + # 异步批量更新边 + edge_update_start = time.time() + if edges_to_update: + batch_size = 100 + for i in range(0, len(edges_to_update), batch_size): + batch = edges_to_update[i:i + batch_size] + await self._async_batch_update_edges(batch) + edge_update_end = time.time() + logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)") + + # 检查需要删除的边 + memory_edge_keys = {(source, target) for source, target, _ in memory_edges} + db_edge_keys = {(edge.source, edge.target) for edge in db_edges} + edges_to_delete = db_edge_keys - memory_edge_keys + + # 异步删除不存在的边 + edge_delete_start = time.time() + if edges_to_delete: + await self._async_delete_edges(edges_to_delete) + edge_delete_end = time.time() + logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)") + + end_time = time.time() + logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒") + logger.success(f"[同步] 同步了 {len(memory_nodes)} 个节点和 {len(memory_edges)} 条边") + + async def _async_batch_create_nodes(self, nodes_data): + """异步批量创建节点""" + try: + GraphNodes.insert_many(nodes_data).execute() + except Exception as e: + logger.error(f"[同步] 批量创建节点失败: {e}") + raise + + async def _async_create_node(self, node_data): + """异步创建单个节点""" + try: + GraphNodes.create(**node_data) + except Exception as e: + logger.error(f"[同步] 创建节点失败: {e}") + raise + + async def _async_batch_update_nodes(self, nodes_data): + """异步批量更新节点""" + try: + for node_data in nodes_data: + GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where( + GraphNodes.concept == node_data['concept'] + ).execute() + except Exception as e: + logger.error(f"[同步] 批量更新节点失败: {e}") + raise + + async def _async_update_node(self, node_data): + """异步更新单个节点""" + try: + GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where( + GraphNodes.concept == node_data['concept'] + ).execute() + except Exception as e: + logger.error(f"[同步] 更新节点失败: {e}") + raise + + async def _async_delete_nodes(self, concepts): + """异步删除节点""" + try: + GraphNodes.delete().where(GraphNodes.concept.in_(concepts)).execute() + except Exception as e: + logger.error(f"[同步] 删除节点失败: {e}") + raise + + async def _async_batch_create_edges(self, edges_data): + """异步批量创建边""" + try: + GraphEdges.insert_many(edges_data).execute() + except Exception as e: + logger.error(f"[同步] 批量创建边失败: {e}") + raise + + async def _async_batch_update_edges(self, edges_data): + """异步批量更新边""" + try: + for edge_data in edges_data: + GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where( + (GraphEdges.source == edge_data['source']) & + (GraphEdges.target == edge_data['target']) + ).execute() + except Exception as e: + logger.error(f"[同步] 批量更新边失败: {e}") + raise + + async def _async_delete_edges(self, edge_keys): + """异步删除边""" + try: + for source, target in edge_keys: + GraphEdges.delete().where( + (GraphEdges.source == source) & + (GraphEdges.target == target) + ).execute() + except Exception as e: + logger.error(f"[同步] 删除边失败: {e}") + raise def sync_memory_from_db(self): """从数据库同步数据到内存中的图结构""" @@ -1111,7 +1398,7 @@ class ParahippocampalGyrus: input_text = await 
build_readable_messages( messages, merge_messages=True, # 合并连续消息 - timestamp_mode="normal", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 + timestamp_mode="normal_no_YMD", # 使用 'YYYY-MM-DD HH:MM:SS' 格式 replace_bot_name=False, # 保留原始用户名 ) @@ -1119,8 +1406,12 @@ class ParahippocampalGyrus: if not input_text: logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。") return set(), {} + + current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d") + current_YMD_time_str = f"当前日期: {current_YMD_time}" + input_text = f"{current_YMD_time_str}\n{input_text}" - logger.debug(f"用于压缩的格式化文本:\n{input_text}") + logger.debug(f"记忆来源:\n{input_text}") # 2. 使用LLM提取关键主题 topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) @@ -1191,7 +1482,7 @@ class ParahippocampalGyrus: return compressed_memory, similar_topics_dict async def operation_build_memory(self): - logger.debug("------------------------------------开始构建记忆--------------------------------------") + logger.info("------------------------------------开始构建记忆--------------------------------------") start_time = time.time() memory_samples = self.hippocampus.entorhinal_cortex.get_memory_sample() all_added_nodes = [] @@ -1199,19 +1490,16 @@ class ParahippocampalGyrus: all_added_edges = [] for i, messages in enumerate(memory_samples, 1): all_topics = [] - progress = (i / len(memory_samples)) * 100 - bar_length = 30 - filled_length = int(bar_length * i // len(memory_samples)) - bar = "█" * filled_length + "-" * (bar_length - filled_length) - logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") - compress_rate = global_config.memory.memory_compress_rate try: compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) except Exception as e: logger.error(f"压缩记忆时发生错误: {e}") continue - logger.debug(f"压缩后记忆数量: {compressed_memory},似曾相识的话题: {similar_topics_dict}") + for topic, memory in compressed_memory: + logger.info(f"取得记忆: {topic} - {memory}") + for topic, similar_topics in similar_topics_dict.items(): + logger.debug(f"相似话题: {topic} - {similar_topics}") current_time = datetime.datetime.now().timestamp() logger.debug(f"添加节点: {', '.join(topic for topic, _ in compressed_memory)}") @@ -1245,10 +1533,20 @@ class ParahippocampalGyrus: logger.debug(f"连接同批次节点: {topic1} 和 {topic2}") all_added_edges.append(f"{topic1}-{topic2}") self.memory_graph.connect_dot(topic1, topic2) + + + progress = (i / len(memory_samples)) * 100 + bar_length = 30 + filled_length = int(bar_length * i // len(memory_samples)) + bar = "█" * filled_length + "-" * (bar_length - filled_length) + logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") - logger.success(f"更新记忆: {', '.join(all_added_nodes)}") - logger.debug(f"强化连接: {', '.join(all_added_edges)}") - logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") + if all_added_nodes: + logger.success(f"更新记忆: {', '.join(all_added_nodes)}") + if all_added_edges: + logger.debug(f"强化连接: {', '.join(all_added_edges)}") + if all_connected_nodes: + logger.info(f"强化连接节点: {', '.join(all_connected_nodes)}") await self.hippocampus.entorhinal_cortex.sync_memory_to_db() diff --git a/src/chat/normal_chat/normal_chat.py b/src/chat/normal_chat/normal_chat.py index 8c724c83a..4b8e0ab63 100644 --- a/src/chat/normal_chat/normal_chat.py +++ b/src/chat/normal_chat/normal_chat.py @@ -24,6 +24,7 @@ from src.chat.focus_chat.planners.action_manager import ActionManager from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner from src.chat.normal_chat.normal_chat_action_modifier import 
NormalChatActionModifier from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor +from src.chat.focus_chat.replyer.default_replyer import DefaultReplyer logger = get_logger("normal_chat") @@ -77,6 +78,9 @@ class NormalChat: # 初始化Normal Chat专用表达器 self.expressor = NormalChatExpressor(self.chat_stream, self.stream_name) + self.replyer = DefaultReplyer(chat_id=self.stream_id) + + self.replyer.chat_stream = self.chat_stream self._initialized = True logger.debug(f"[{self.stream_name}] NormalChat 初始化完成 (异步部分)。") @@ -93,7 +97,7 @@ class NormalChat: ) thinking_time_point = round(time.time(), 2) - thinking_id = "mt" + str(thinking_time_point) + thinking_id = "tid" + str(thinking_time_point) thinking_message = MessageThinking( message_id=thinking_id, chat_stream=self.chat_stream, @@ -232,7 +236,6 @@ class NormalChat: message=message, is_mentioned=is_mentioned, interested_rate=interest_value * self.willing_amplifier, - rewind_response=False, ) except Exception as e: logger.error(f"[{self.stream_name}] 处理兴趣消息{msg_id}时出错: {e}\n{traceback.format_exc()}") @@ -241,7 +244,7 @@ class NormalChat: # 改为实例方法, 移除 chat 参数 async def normal_response( - self, message: MessageRecv, is_mentioned: bool, interested_rate: float, rewind_response: bool = False + self, message: MessageRecv, is_mentioned: bool, interested_rate: float ) -> None: # 新增:如果已停用,直接返回 if self._disabled: @@ -284,11 +287,8 @@ class NormalChat: # 回复前处理 await willing_manager.before_generate_reply_handle(message.message_info.message_id) - with Timer("创建思考消息", timing_results): - if rewind_response: - thinking_id = await self._create_thinking_message(message, message.message_info.time) - else: - thinking_id = await self._create_thinking_message(message) + + thinking_id = await self._create_thinking_message(message) logger.debug(f"[{self.stream_name}] 创建捕捉器,thinking_id:{thinking_id}") @@ -666,6 +666,7 @@ class NormalChat: thinking_id=thinking_id, observations=[], # normal_chat不使用observations expressor=self.expressor, # 使用normal_chat专用的expressor + replyer=self.replyer, chat_stream=self.chat_stream, log_prefix=self.stream_name, shutting_down=self._disabled, diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index e896420aa..59cac2139 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -342,7 +342,7 @@ async def _build_readable_messages_internal( # 使用指定的 timestamp_mode 格式化时间 readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode) - header = f"{readable_time}{merged['name']} 说:" + header = f"{readable_time}, {merged['name']} :" output_lines.append(header) # 将内容合并,并添加缩进 for line in merged["content"]: diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 0e1113aff..19bbfe2c4 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -128,38 +128,38 @@ class ImageManager: return f"[表情包,含义看起来是:{cached_description}]" # 根据配置决定是否保存图片 - if global_config.emoji.save_emoji: - # 生成文件名和路径 - logger.debug(f"保存表情包: {image_hash}") - current_timestamp = time.time() - filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" - emoji_dir = os.path.join(self.IMAGE_DIR, "emoji") - os.makedirs(emoji_dir, exist_ok=True) - file_path = os.path.join(emoji_dir, filename) + # if global_config.emoji.save_emoji: + # 生成文件名和路径 + logger.debug(f"保存表情包: {image_hash}") + current_timestamp = time.time() + filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}" + emoji_dir = 
os.path.join(self.IMAGE_DIR, "emoji") + os.makedirs(emoji_dir, exist_ok=True) + file_path = os.path.join(emoji_dir, filename) + try: + # 保存文件 + with open(file_path, "wb") as f: + f.write(image_bytes) + + # 保存到数据库 (Images表) try: - # 保存文件 - with open(file_path, "wb") as f: - f.write(image_bytes) - - # 保存到数据库 (Images表) - try: - img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) - img_obj.path = file_path - img_obj.description = description - img_obj.timestamp = current_timestamp - img_obj.save() - except Images.DoesNotExist: - Images.create( - emoji_hash=image_hash, - path=file_path, - type="emoji", - description=description, - timestamp=current_timestamp, - ) - # logger.debug(f"保存表情包元数据: {file_path}") - except Exception as e: - logger.error(f"保存表情包文件或元数据失败: {str(e)}") + img_obj = Images.get((Images.emoji_hash == image_hash) & (Images.type == "emoji")) + img_obj.path = file_path + img_obj.description = description + img_obj.timestamp = current_timestamp + img_obj.save() + except Images.DoesNotExist: + Images.create( + emoji_hash=image_hash, + path=file_path, + type="emoji", + description=description, + timestamp=current_timestamp, + ) + # logger.debug(f"保存表情包元数据: {file_path}") + except Exception as e: + logger.error(f"保存表情包文件或元数据失败: {str(e)}") # 保存描述到数据库 (ImageDescriptions表) self._save_description_to_db(image_hash, description, "emoji") diff --git a/src/common/remote.py b/src/common/remote.py index 064a07cb0..b61a43d8c 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -124,9 +124,7 @@ class TelemetryHeartBeatTask(AsyncTask): timeout=5, # 设置超时时间为5秒 ) except Exception as e: - # 你知道为什么设置成debug吗? - # 因为我不想看到 - logger.debug(f"心跳发送失败: {e}") + logger.warning(f"(此错误不会影响正常使用)状态未发生: {e}") logger.debug(response) @@ -136,21 +134,21 @@ class TelemetryHeartBeatTask(AsyncTask): logger.debug(f"心跳发送成功,状态码: {response.status_code}") elif response.status_code == 403: # 403 Forbidden - logger.error( - "心跳发送失败,403 Forbidden: 可能是UUID无效或未注册。" + logger.warning( + "(此错误不会影响正常使用)心跳发送失败,403 Forbidden: 可能是UUID无效或未注册。" "处理措施:重置UUID,下次发送心跳时将尝试重新注册。" ) self.client_uuid = None del local_storage["mmc_uuid"] # 删除本地存储的UUID else: # 其他错误 - logger.error(f"心跳发送失败,状态码: {response.status_code}, 响应内容: {response.text}") + logger.warning(f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}") async def run(self): # 发送心跳 if global_config.telemetry.enable: if self.client_uuid is None and not await self._req_uuid(): - logger.error("获取UUID失败,跳过此次心跳") + logger.warning("获取UUID失败,跳过此次心跳") return await self._send_heartbeat() diff --git a/src/individuality/expression_style.py b/src/individuality/expression_style.py index 29b687076..0d650ce46 100644 --- a/src/individuality/expression_style.py +++ b/src/individuality/expression_style.py @@ -6,6 +6,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from typing import List, Tuple import os import json +from datetime import datetime logger = get_logger("expressor") @@ -45,11 +46,30 @@ class PersonalityExpression: if os.path.exists(self.meta_file_path): try: with open(self.meta_file_path, "r", encoding="utf-8") as f: - return json.load(f) + meta_data = json.load(f) + # 检查是否有last_update_time字段 + if "last_update_time" not in meta_data: + logger.warning(f"{self.meta_file_path} 中缺少last_update_time字段,将重新开始。") + # 清空并重写元数据文件 + self._write_meta_data({"last_style_text": None, "count": 0, "last_update_time": None}) + # 清空并重写表达文件 + if os.path.exists(self.expressions_file_path): + with open(self.expressions_file_path, "w", 
encoding="utf-8") as f: + json.dump([], f, ensure_ascii=False, indent=2) + logger.debug(f"已清空表达文件: {self.expressions_file_path}") + return {"last_style_text": None, "count": 0, "last_update_time": None} + return meta_data except json.JSONDecodeError: logger.warning(f"无法解析 {self.meta_file_path} 中的JSON数据,将重新开始。") - return {"last_style_text": None, "count": 0} - return {"last_style_text": None, "count": 0} + # 清空并重写元数据文件 + self._write_meta_data({"last_style_text": None, "count": 0, "last_update_time": None}) + # 清空并重写表达文件 + if os.path.exists(self.expressions_file_path): + with open(self.expressions_file_path, "w", encoding="utf-8") as f: + json.dump([], f, ensure_ascii=False, indent=2) + logger.debug(f"已清空表达文件: {self.expressions_file_path}") + return {"last_style_text": None, "count": 0, "last_update_time": None} + return {"last_style_text": None, "count": 0, "last_update_time": None} def _write_meta_data(self, data): os.makedirs(os.path.dirname(self.meta_file_path), exist_ok=True) @@ -84,7 +104,7 @@ class PersonalityExpression: if count >= self.max_calculations: logger.debug(f"对于风格 '{current_style_text}' 已达到最大计算次数 ({self.max_calculations})。跳过提取。") # 即使跳过,也更新元数据以反映当前风格已被识别且计数已满 - self._write_meta_data({"last_style_text": current_style_text, "count": count}) + self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")}) return # 构建prompt @@ -99,30 +119,63 @@ class PersonalityExpression: except Exception as e: logger.error(f"个性表达方式提取失败: {e}") # 如果提取失败,保存当前的风格和未增加的计数 - self._write_meta_data({"last_style_text": current_style_text, "count": count}) + self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")}) return logger.info(f"个性表达方式提取response: {response}") # chat_id用personality - expressions = self.parse_expression_response(response, "personality") + # 转为dict并count=100 - result = [] - for _, situation, style in expressions: - result.append({"situation": situation, "style": style, "count": 100}) - # 超过50条时随机删除多余的,只保留50条 - if len(result) > 50: - remove_count = len(result) - 50 - remove_indices = set(random.sample(range(len(result)), remove_count)) - result = [item for idx, item in enumerate(result) if idx not in remove_indices] + if response != "": + expressions = self.parse_expression_response(response, "personality") + # 读取已有的表达方式 + existing_expressions = [] + if os.path.exists(self.expressions_file_path): + try: + with open(self.expressions_file_path, "r", encoding="utf-8") as f: + existing_expressions = json.load(f) + except (json.JSONDecodeError, FileNotFoundError): + logger.warning(f"无法读取或解析 {self.expressions_file_path},将创建新的表达文件。") + + # 创建新的表达方式 + new_expressions = [] + for _, situation, style in expressions: + new_expressions.append({"situation": situation, "style": style, "count": 1}) + + # 合并表达方式,如果situation和style相同则累加count + merged_expressions = existing_expressions.copy() + for new_expr in new_expressions: + found = False + for existing_expr in merged_expressions: + if (existing_expr["situation"] == new_expr["situation"] and + existing_expr["style"] == new_expr["style"]): + existing_expr["count"] += new_expr["count"] + found = True + break + if not found: + merged_expressions.append(new_expr) + + # 超过50条时随机删除多余的,只保留50条 + if len(merged_expressions) > 50: + remove_count = len(merged_expressions) - 50 + remove_indices = set(random.sample(range(len(merged_expressions)), remove_count)) + merged_expressions = [item for idx, item in 
enumerate(merged_expressions) if idx not in remove_indices] - with open(self.expressions_file_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - logger.info(f"已写入{len(result)}条表达到{self.expressions_file_path}") + with open(self.expressions_file_path, "w", encoding="utf-8") as f: + json.dump(merged_expressions, f, ensure_ascii=False, indent=2) + logger.info(f"已写入{len(merged_expressions)}条表达到{self.expressions_file_path}") - # 成功提取后更新元数据 - count += 1 - self._write_meta_data({"last_style_text": current_style_text, "count": count}) - logger.info(f"成功处理。风格 '{current_style_text}' 的计数现在是 {count}。") + # 成功提取后更新元数据 + count += 1 + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + self._write_meta_data({ + "last_style_text": current_style_text, + "count": count, + "last_update_time": current_time + }) + logger.info(f"成功处理。风格 '{current_style_text}' 的计数现在是 {count},最后更新时间:{current_time}。") + else: + logger.warning(f"个性表达方式提取失败,模型返回空内容: {response}") def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]: """ diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index bc1844c1b..9334405d4 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "2.10.0" +version = "2.11.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更