Merge pull request #1020 from tcmofashi/dev

feat: add a change_to_focus_chat action to normal_chat
tcmofashi
2025-06-03 16:11:47 +08:00
committed by GitHub
15 changed files with 445 additions and 362 deletions

View File

@@ -1,33 +1,34 @@
import json
import os
from pathlib import Path
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
import matplotlib as mpl
import sqlite3
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 使用微软雅黑
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"] # 使用微软雅黑
plt.rcParams["axes.unicode_minus"] = False # 用来正常显示负号
plt.rcParams["font.family"] = "sans-serif"
# 获取脚本所在目录
SCRIPT_DIR = Path(__file__).parent
def get_group_name(stream_id):
"""从数据库中获取群组名称"""
conn = sqlite3.connect('data/maibot.db')
conn = sqlite3.connect("data/maibot.db")
cursor = conn.cursor()
cursor.execute('''
cursor.execute(
"""
SELECT group_name, user_nickname, platform
FROM chat_streams
WHERE stream_id = ?
''', (stream_id,))
""",
(stream_id,),
)
result = cursor.fetchone()
conn.close()
@@ -42,13 +43,14 @@ def get_group_name(stream_id):
return f"{platform}-{stream_id[:8]}"
return stream_id
def load_group_data(group_dir):
"""加载单个群组的数据"""
json_path = Path(group_dir) / "expressions.json"
if not json_path.exists():
return [], [], []
with open(json_path, 'r', encoding='utf-8') as f:
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
situations = []
@@ -56,13 +58,14 @@ def load_group_data(group_dir):
combined = []
for item in data:
count = item['count']
situations.extend([item['situation']] * count)
styles.extend([item['style']] * count)
count = item["count"]
situations.extend([item["situation"]] * count)
styles.extend([item["style"]] * count)
combined.extend([f"{item['situation']} {item['style']}"] * count)
return situations, styles, combined
def analyze_group_similarity():
# 获取所有群组目录
base_dir = Path("data/expression/learnt_style")
@@ -79,9 +82,9 @@ def analyze_group_similarity():
for d in group_dirs:
situations, styles, combined = load_group_data(d)
group_situations.append(' '.join(situations))
group_styles.append(' '.join(styles))
group_combined.append(' '.join(combined))
group_situations.append(" ".join(situations))
group_styles.append(" ".join(styles))
group_combined.append(" ".join(combined))
# 创建TF-IDF向量化器
vectorizer = TfidfVectorizer()
@@ -101,60 +104,66 @@ def analyze_group_similarity():
# 场景相似度热力图
plt.subplot(1, 3, 1)
sns.heatmap(log_situation_matrix,
sns.heatmap(
log_situation_matrix,
xticklabels=group_names,
yticklabels=group_names,
cmap='YlOrRd',
cmap="YlOrRd",
annot=True,
fmt='.2f',
fmt=".2f",
vmin=0,
vmax=np.log1p(0.2))
plt.title('群组场景相似度热力图 (对数变换)')
plt.xticks(rotation=45, ha='right')
vmax=np.log1p(0.2),
)
plt.title("群组场景相似度热力图 (对数变换)")
plt.xticks(rotation=45, ha="right")
# 表达方式相似度热力图
plt.subplot(1, 3, 2)
sns.heatmap(log_style_matrix,
sns.heatmap(
log_style_matrix,
xticklabels=group_names,
yticklabels=group_names,
cmap='YlOrRd',
cmap="YlOrRd",
annot=True,
fmt='.2f',
fmt=".2f",
vmin=0,
vmax=np.log1p(0.2))
plt.title('群组表达方式相似度热力图 (对数变换)')
plt.xticks(rotation=45, ha='right')
vmax=np.log1p(0.2),
)
plt.title("群组表达方式相似度热力图 (对数变换)")
plt.xticks(rotation=45, ha="right")
# 组合相似度热力图
plt.subplot(1, 3, 3)
sns.heatmap(log_combined_matrix,
sns.heatmap(
log_combined_matrix,
xticklabels=group_names,
yticklabels=group_names,
cmap='YlOrRd',
cmap="YlOrRd",
annot=True,
fmt='.2f',
fmt=".2f",
vmin=0,
vmax=np.log1p(0.2))
plt.title('群组场景+表达方式相似度热力图 (对数变换)')
plt.xticks(rotation=45, ha='right')
vmax=np.log1p(0.2),
)
plt.title("群组场景+表达方式相似度热力图 (对数变换)")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.savefig(SCRIPT_DIR / 'group_similarity_heatmaps.png', dpi=300, bbox_inches='tight')
plt.savefig(SCRIPT_DIR / "group_similarity_heatmaps.png", dpi=300, bbox_inches="tight")
plt.close()
# 保存匹配详情到文本文件
with open(SCRIPT_DIR / 'group_similarity_details.txt', 'w', encoding='utf-8') as f:
f.write('群组相似度详情\n')
f.write('=' * 50 + '\n\n')
with open(SCRIPT_DIR / "group_similarity_details.txt", "w", encoding="utf-8") as f:
f.write("群组相似度详情\n")
f.write("=" * 50 + "\n\n")
for i in range(len(group_ids)):
for j in range(i+1, len(group_ids)):
for j in range(i + 1, len(group_ids)):
if log_combined_matrix[i][j] > np.log1p(0.05):
f.write(f'群组1: {group_names[i]}\n')
f.write(f'群组2: {group_names[j]}\n')
f.write(f'场景相似度: {situation_matrix[i][j]:.4f}\n')
f.write(f'表达方式相似度: {style_matrix[i][j]:.4f}\n')
f.write(f'组合相似度: {combined_matrix[i][j]:.4f}\n')
f.write(f"群组1: {group_names[i]}\n")
f.write(f"群组2: {group_names[j]}\n")
f.write(f"场景相似度: {situation_matrix[i][j]:.4f}\n")
f.write(f"表达方式相似度: {style_matrix[i][j]:.4f}\n")
f.write(f"组合相似度: {combined_matrix[i][j]:.4f}\n")
# 获取两个群组的数据
situations1, styles1, _ = load_group_data(group_dirs[i])
@@ -163,18 +172,19 @@ def analyze_group_similarity():
# 找出共同的场景
common_situations = set(situations1) & set(situations2)
if common_situations:
f.write('\n共同场景:\n')
f.write("\n共同场景:\n")
for situation in common_situations:
f.write(f'- {situation}\n')
f.write(f"- {situation}\n")
# 找出共同的表达方式
common_styles = set(styles1) & set(styles2)
if common_styles:
f.write('\n共同表达方式:\n')
f.write("\n共同表达方式:\n")
for style in common_styles:
f.write(f'- {style}\n')
f.write(f"- {style}\n")
f.write("\n" + "-" * 50 + "\n\n")
f.write('\n' + '-' * 50 + '\n\n')
if __name__ == "__main__":
analyze_group_similarity()
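
The script above joins each group's learned expressions into one text per group, compares groups with TF-IDF cosine similarity, and log-transforms the matrix before plotting the heatmaps. A minimal sketch of that core step, with hypothetical group texts and assuming scikit-learn and NumPy are installed:

```python
# Minimal sketch of the pairwise group-similarity step used above; the two
# group_texts entries are hypothetical stand-ins for " ".join(situations) etc.
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

group_texts = [
    "下雨了 提醒带伞",        # group A: joined situation/style text (hypothetical)
    "下雨了 发一句注意安全",  # group B
]

tfidf = TfidfVectorizer().fit_transform(group_texts)  # one row per group
similarity = cosine_similarity(tfidf)                 # symmetric matrix in [0, 1]
log_similarity = np.log1p(similarity)                 # same log transform as the heatmaps

print(log_similarity.round(2))
```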

View File

@@ -32,7 +32,6 @@ from rich.panel import Panel
from src.common.database.database import db
from src.common.database.database_model import (
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,

View File

@@ -8,9 +8,9 @@ import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from collections import defaultdict
class ExpressionViewer:
def __init__(self, root):
self.root = root
@@ -30,7 +30,7 @@ class ExpressionViewer:
self.search_frame.pack(fill=tk.X, pady=(0, 10))
self.search_var = tk.StringVar()
self.search_var.trace('w', self.filter_expressions)
self.search_var.trace("w", self.filter_expressions)
self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var)
self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5))
@@ -39,40 +39,51 @@ class ExpressionViewer:
self.file_var = tk.StringVar()
self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var)
self.file_combo.pack(side=tk.LEFT, padx=5)
self.file_combo.bind('<<ComboboxSelected>>', self.load_file)
self.file_combo.bind("<<ComboboxSelected>>", self.load_file)
# 创建排序选项
self.sort_frame = ttk.LabelFrame(self.control_frame, text="排序选项")
self.sort_frame.pack(fill=tk.X, pady=5)
self.sort_var = tk.StringVar(value="count")
ttk.Radiobutton(self.sort_frame, text="按计数排序", variable=self.sort_var,
value="count", command=self.apply_sort).pack(anchor=tk.W)
ttk.Radiobutton(self.sort_frame, text="按情境排序", variable=self.sort_var,
value="situation", command=self.apply_sort).pack(anchor=tk.W)
ttk.Radiobutton(self.sort_frame, text="风格排序", variable=self.sort_var,
value="style", command=self.apply_sort).pack(anchor=tk.W)
ttk.Radiobutton(
self.sort_frame, text="按计数排序", variable=self.sort_var, value="count", command=self.apply_sort
).pack(anchor=tk.W)
ttk.Radiobutton(
self.sort_frame, text="情境排序", variable=self.sort_var, value="situation", command=self.apply_sort
).pack(anchor=tk.W)
ttk.Radiobutton(
self.sort_frame, text="按风格排序", variable=self.sort_var, value="style", command=self.apply_sort
).pack(anchor=tk.W)
# 创建分群选项
self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项")
self.group_frame.pack(fill=tk.X, pady=5)
self.group_var = tk.StringVar(value="none")
ttk.Radiobutton(self.group_frame, text="不分群", variable=self.group_var,
value="none", command=self.apply_grouping).pack(anchor=tk.W)
ttk.Radiobutton(self.group_frame, text="按情境分群", variable=self.group_var,
value="situation", command=self.apply_grouping).pack(anchor=tk.W)
ttk.Radiobutton(self.group_frame, text="风格分群", variable=self.group_var,
value="style", command=self.apply_grouping).pack(anchor=tk.W)
ttk.Radiobutton(
self.group_frame, text="不分群", variable=self.group_var, value="none", command=self.apply_grouping
).pack(anchor=tk.W)
ttk.Radiobutton(
self.group_frame, text="情境分群", variable=self.group_var, value="situation", command=self.apply_grouping
).pack(anchor=tk.W)
ttk.Radiobutton(
self.group_frame, text="按风格分群", variable=self.group_var, value="style", command=self.apply_grouping
).pack(anchor=tk.W)
# 创建相似度阈值滑块
self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置")
self.similarity_frame.pack(fill=tk.X, pady=5)
self.similarity_var = tk.DoubleVar(value=0.5)
self.similarity_scale = ttk.Scale(self.similarity_frame, from_=0.0, to=1.0,
variable=self.similarity_var, orient=tk.HORIZONTAL,
command=self.update_similarity)
self.similarity_scale = ttk.Scale(
self.similarity_frame,
from_=0.0,
to=1.0,
variable=self.similarity_var,
orient=tk.HORIZONTAL,
command=self.update_similarity,
)
self.similarity_scale.pack(fill=tk.X, padx=5, pady=5)
ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack()
@@ -81,8 +92,9 @@ class ExpressionViewer:
self.view_frame.pack(fill=tk.X, pady=5)
self.show_graph_var = tk.BooleanVar(value=True)
ttk.Checkbutton(self.view_frame, text="显示关系图", variable=self.show_graph_var,
command=self.toggle_graph).pack(anchor=tk.W)
ttk.Checkbutton(
self.view_frame, text="显示关系图", variable=self.show_graph_var, command=self.toggle_graph
).pack(anchor=tk.W)
# 创建右侧内容区域
self.content_frame = ttk.Frame(self.main_frame)
@@ -114,11 +126,11 @@ class ExpressionViewer:
files = []
for root, _, filenames in os.walk(expression_dir):
for filename in filenames:
if filename.endswith('.json'):
if filename.endswith(".json"):
rel_path = os.path.relpath(os.path.join(root, filename), expression_dir)
files.append(rel_path)
self.file_combo['values'] = files
self.file_combo["values"] = files
if files:
self.file_combo.set(files[0])
self.load_file(None)
@@ -130,7 +142,7 @@ class ExpressionViewer:
file_path = os.path.join("data/expression", selected_file)
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
self.current_data = json.load(f)
self.apply_sort()
@@ -216,18 +228,17 @@ class ExpressionViewer:
pos = nx.spring_layout(self.graph)
# 绘制节点
nx.draw_networkx_nodes(self.graph, pos, node_color='lightblue',
node_size=1000, alpha=0.6)
nx.draw_networkx_nodes(self.graph, pos, node_color="lightblue", node_size=1000, alpha=0.6)
# 绘制边
nx.draw_networkx_edges(self.graph, pos, alpha=0.4)
# 添加标签
labels = nx.get_node_attributes(self.graph, 'label')
labels = nx.get_node_attributes(self.graph, "label")
nx.draw_networkx_labels(self.graph, pos, labels, font_size=8)
plt.title("表达方式关系图")
plt.axis('off')
plt.axis("off")
self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame)
self.canvas.draw()
@@ -249,17 +260,19 @@ class ExpressionViewer:
filtered_data = []
for item in self.current_data:
situation = item.get('situation', '').lower()
style = item.get('style', '').lower()
situation = item.get("situation", "").lower()
style = item.get("style", "").lower()
if search_text in situation or search_text in style:
filtered_data.append(item)
self.display_data(filtered_data)
def main():
root = tk.Tk()
app = ExpressionViewer(root)
# app = ExpressionViewer(root)
root.mainloop()
if __name__ == "__main__":
main()
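
Besides the list view, the viewer above renders a networkx relation graph with a spring layout and per-node labels. A small headless sketch of that drawing pattern; the Agg backend and English labels are assumptions made here so the example runs without a Tk display or CJK fonts:

```python
# Sketch of the relation-graph drawing used by the viewer: spring layout,
# light-blue nodes, labels read from each node's "label" attribute.
import matplotlib
matplotlib.use("Agg")  # render to a file, no window needed for this example
import matplotlib.pyplot as plt
import networkx as nx

graph = nx.Graph()
graph.add_node("a", label="situation: raining")
graph.add_node("b", label="style: remind to bring an umbrella")
graph.add_edge("a", "b")

pos = nx.spring_layout(graph, seed=0)
nx.draw_networkx_nodes(graph, pos, node_color="lightblue", node_size=1000, alpha=0.6)
nx.draw_networkx_edges(graph, pos, alpha=0.4)
nx.draw_networkx_labels(graph, pos, nx.get_node_attributes(graph, "label"), font_size=8)
plt.axis("off")
plt.savefig("relation_graph_example.png", dpi=150, bbox_inches="tight")
```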

View File

@@ -602,9 +602,10 @@ class EmojiManager:
continue
# 检查是否需要处理表情包(数量超过最大值或不足)
if global_config.emoji.steal_emoji and ((self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
self.emoji_num < self.emoji_num_max
)):
if global_config.emoji.steal_emoji and (
(self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace)
or (self.emoji_num < self.emoji_num_max)
):
try:
# 获取目录下所有图片文件
files_to_process = [
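
The reformatted condition above is unchanged in behavior: emojis are only processed when stealing is enabled and the collection is either over the cap with replacement allowed, or still under the cap. A hypothetical helper restating that decision, with parameter names mirroring the config flags:

```python
# Illustrative restatement of the steal_emoji decision; not part of EmojiManager.
def should_process_emojis(steal_emoji: bool, do_replace: bool, emoji_num: int, emoji_num_max: int) -> bool:
    over_cap_and_replace = emoji_num > emoji_num_max and do_replace
    under_cap = emoji_num < emoji_num_max
    return steal_emoji and (over_cap_and_replace or under_cap)

assert should_process_emojis(True, True, 120, 100)       # over the cap, replacement allowed
assert should_process_emojis(True, False, 80, 100)       # still room for more emojis
assert not should_process_emojis(False, True, 120, 100)  # stealing disabled
```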

View File

@@ -264,9 +264,6 @@ class ActionPlanner(BasePlanner):
action_result = {"action_type": action, "action_data": action_data, "reasoning": reasoning}
plan_result = {
"action_result": action_result,
# "extra_info_block": extra_info_block,

View File

@@ -1,4 +1,4 @@
from typing import Dict, Any, List, Optional, Set, Tuple
from typing import Dict, Any, Tuple
import time
import random
import string

View File

@@ -286,110 +286,110 @@ class MemoryManager:
logger.error(f"生成总结时出错: {str(e)}")
return default_summary
# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
# """
# 对记忆进行精简操作,根据要求修改要点、总结和概括
# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
# """
# 对记忆进行精简操作,根据要求修改要点、总结和概括
# Args:
# memory_id: 记忆ID
# requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
# Args:
# memory_id: 记忆ID
# requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
# Returns:
# 修改后的记忆总结字典
# """
# # 获取指定ID的记忆项
# logger.info(f"精简记忆: {memory_id}")
# memory_item = self.get_by_id(memory_id)
# if not memory_item:
# raise ValueError(f"未找到ID为{memory_id}的记忆项")
# Returns:
# 修改后的记忆总结字典
# """
# # 获取指定ID的记忆项
# logger.info(f"精简记忆: {memory_id}")
# memory_item = self.get_by_id(memory_id)
# if not memory_item:
# raise ValueError(f"未找到ID为{memory_id}的记忆项")
# # 增加精简次数
# memory_item.increase_compress_count()
# # 增加精简次数
# memory_item.increase_compress_count()
# summary = memory_item.summary
# summary = memory_item.summary
# # 使用LLM根据要求对总结、概括和要点进行精简修改
# prompt = f"""
# 请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
# 要求:{requirements}
# 你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题
# # 使用LLM根据要求对总结、概括和要点进行精简修改
# prompt = f"""
# 请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
# 要求:{requirements}
# 你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题
# 目前主题:{summary["brief"]}
# 目前主题:{summary["brief"]}
# 目前关键要点:
# {chr(10).join([f"- {point}" for point in summary.get("points", [])])}
# 目前关键要点:
# {chr(10).join([f"- {point}" for point in summary.get("points", [])])}
# 请生成修改后的主题和关键要点,遵循以下格式:
# ```json
# {{
# "brief": "修改后的主题20字以内",
# "points": [
# "修改后的要点",
# "修改后的要点"
# ]
# }}
# ```
# 请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
# """
# # 定义默认的精简结果
# default_refined = {
# "brief": summary["brief"],
# "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点
# }
# 请生成修改后的主题和关键要点,遵循以下格式:
# ```json
# {{
# "brief": "修改后的主题20字以内",
# "points": [
# "修改后的要点",
# "修改后的要点"
# ]
# }}
# ```
# 请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
# """
# # 定义默认的精简结果
# default_refined = {
# "brief": summary["brief"],
# "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点
# }
# try:
# # 调用LLM修改总结、概括和要点
# response, _ = await self.llm_summarizer.generate_response_async(prompt)
# logger.debug(f"精简记忆响应: {response}")
# # 使用repair_json处理响应
# try:
# # 修复JSON格式
# fixed_json_string = repair_json(response)
# try:
# # 调用LLM修改总结、概括和要点
# response, _ = await self.llm_summarizer.generate_response_async(prompt)
# logger.debug(f"精简记忆响应: {response}")
# # 使用repair_json处理响应
# try:
# # 修复JSON格式
# fixed_json_string = repair_json(response)
# # 将修复后的字符串解析为Python对象
# if isinstance(fixed_json_string, str):
# try:
# refined_data = json.loads(fixed_json_string)
# except json.JSONDecodeError as decode_error:
# logger.error(f"JSON解析错误: {str(decode_error)}")
# refined_data = default_refined
# else:
# # 如果repair_json直接返回了字典对象直接使用
# refined_data = fixed_json_string
# # 将修复后的字符串解析为Python对象
# if isinstance(fixed_json_string, str):
# try:
# refined_data = json.loads(fixed_json_string)
# except json.JSONDecodeError as decode_error:
# logger.error(f"JSON解析错误: {str(decode_error)}")
# refined_data = default_refined
# else:
# # 如果repair_json直接返回了字典对象直接使用
# refined_data = fixed_json_string
# # 确保是字典类型
# if not isinstance(refined_data, dict):
# logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
# refined_data = default_refined
# # 确保是字典类型
# if not isinstance(refined_data, dict):
# logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
# refined_data = default_refined
# # 更新总结
# summary["brief"] = refined_data.get("brief", "主题未知的记忆")
# # 更新总结
# summary["brief"] = refined_data.get("brief", "主题未知的记忆")
# # 更新关键要点
# points = refined_data.get("points", [])
# if isinstance(points, list) and points:
# # 确保所有要点都是字符串
# summary["points"] = [str(point) for point in points if point is not None]
# else:
# # 如果points不是列表或为空使用默认值
# summary["points"] = ["主要要点已遗忘"]
# # 更新关键要点
# points = refined_data.get("points", [])
# if isinstance(points, list) and points:
# # 确保所有要点都是字符串
# summary["points"] = [str(point) for point in points if point is not None]
# else:
# # 如果points不是列表或为空使用默认值
# summary["points"] = ["主要要点已遗忘"]
# except Exception as e:
# logger.error(f"精简记忆出错: {str(e)}")
# traceback.print_exc()
# except Exception as e:
# logger.error(f"精简记忆出错: {str(e)}")
# traceback.print_exc()
# # 出错时使用简化的默认精简
# summary["brief"] = summary["brief"] + " (已简化)"
# summary["points"] = summary.get("points", ["未知的要点"])[:1]
# # 出错时使用简化的默认精简
# summary["brief"] = summary["brief"] + " (已简化)"
# summary["points"] = summary.get("points", ["未知的要点"])[:1]
# except Exception as e:
# logger.error(f"精简记忆调用LLM出错: {str(e)}")
# traceback.print_exc()
# except Exception as e:
# logger.error(f"精简记忆调用LLM出错: {str(e)}")
# traceback.print_exc()
# # 更新原记忆项的总结
# memory_item.set_summary(summary)
# # 更新原记忆项的总结
# memory_item.set_summary(summary)
# return memory_item
# return memory_item
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
"""

View File

@@ -1,6 +1,5 @@
from typing import List, Any, Optional
import asyncio
import random
from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem

View File

@@ -24,7 +24,6 @@ from rich.traceback import install
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from peewee import Case
install(extra_lines=3)
@@ -816,7 +815,7 @@ class EntorhinalCortex:
timestamps = sample_scheduler.get_timestamp_array()
# 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
for _, readable_timestamp in zip(timestamps, readable_timestamps):
logger.debug(f"回忆往事: {readable_timestamp}")
chat_samples = []
for timestamp in timestamps:
@@ -852,7 +851,11 @@ class EntorhinalCortex:
chat_id = chosen_message[0].get("chat_id")
messages = get_raw_msg_by_timestamp_with_chat(
timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
limit=chat_size,
limit_mode="earliest",
chat_id=chat_id,
)
if messages:
@@ -943,13 +946,15 @@ class EntorhinalCortex:
if concept not in db_nodes:
# 数据库中缺少的节点,添加到创建列表
nodes_to_create.append({
'concept': concept,
'memory_items': memory_items_json,
'hash': memory_hash,
'created_time': created_time,
'last_modified': last_modified
})
nodes_to_create.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"created_time": created_time,
"last_modified": last_modified,
}
)
logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
else:
# 获取数据库中节点的特征值
@@ -958,12 +963,14 @@ class EntorhinalCortex:
# 如果特征值不同,则添加到更新列表
if db_hash != memory_hash:
nodes_to_update.append({
'concept': concept,
'memory_items': memory_items_json,
'hash': memory_hash,
'last_modified': last_modified
})
nodes_to_update.append(
{
"concept": concept,
"memory_items": memory_items_json,
"hash": memory_hash,
"last_modified": last_modified,
}
)
# 检查需要删除的节点
memory_concepts = {concept for concept, _ in memory_nodes}
@@ -972,7 +979,9 @@ class EntorhinalCortex:
node_process_end = time.time()
logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}")
logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
logger.info(
f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点"
)
# 异步批量创建新节点
node_create_start = time.time()
@@ -981,22 +990,24 @@ class EntorhinalCortex:
# 验证所有要创建的节点数据
valid_nodes_to_create = []
for node_data in nodes_to_create:
if not node_data.get('memory_items'):
if not node_data.get("memory_items"):
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
continue
try:
# 验证 JSON 字符串
json.loads(node_data['memory_items'])
json.loads(node_data["memory_items"])
valid_nodes_to_create.append(node_data)
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
logger.warning(
f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
)
continue
if valid_nodes_to_create:
# 使用异步批量插入
batch_size = 100
for i in range(0, len(valid_nodes_to_create), batch_size):
batch = valid_nodes_to_create[i:i + batch_size]
batch = valid_nodes_to_create[i : i + batch_size]
await self._async_batch_create_nodes(batch)
logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
else:
@@ -1006,21 +1017,25 @@ class EntorhinalCortex:
# 尝试逐个创建以找出问题节点
for node_data in nodes_to_create:
try:
if not node_data.get('memory_items'):
if not node_data.get("memory_items"):
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
continue
try:
json.loads(node_data['memory_items'])
json.loads(node_data["memory_items"])
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
logger.warning(
f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
)
continue
await self._async_create_node(node_data)
except Exception as e:
logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
# 从图中移除问题节点
self.memory_graph.G.remove_node(node_data['concept'])
self.memory_graph.G.remove_node(node_data["concept"])
node_create_end = time.time()
logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
logger.info(
f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)"
)
# 异步批量更新节点
node_update_start = time.time()
@@ -1028,30 +1043,32 @@ class EntorhinalCortex:
# 按批次更新节点每批100个
batch_size = 100
for i in range(0, len(nodes_to_update), batch_size):
batch = nodes_to_update[i:i + batch_size]
batch = nodes_to_update[i : i + batch_size]
try:
# 验证批次中的每个节点数据
valid_batch = []
for node_data in batch:
# 确保 memory_items 不为空且是有效的 JSON 字符串
if not node_data.get('memory_items'):
if not node_data.get("memory_items"):
logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
continue
try:
# 验证 JSON 字符串是否有效
json.loads(node_data['memory_items'])
json.loads(node_data["memory_items"])
valid_batch.append(node_data)
except json.JSONDecodeError:
logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
logger.warning(
f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
)
continue
if not valid_batch:
logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
logger.warning(f"[同步] 批次 {i // batch_size + 1} 没有有效的节点可以更新")
continue
# 异步批量更新节点
await self._async_batch_update_nodes(valid_batch)
logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
logger.debug(f"[同步] 成功更新批次 {i // batch_size + 1} 中的 {len(valid_batch)} 个节点")
except Exception as e:
logger.error(f"[同步] 批量更新节点失败: {e}")
# 如果批量更新失败,尝试逐个更新
@@ -1061,17 +1078,21 @@ class EntorhinalCortex:
except Exception as e:
logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
# 从图中移除问题节点
self.memory_graph.G.remove_node(node_data['concept'])
self.memory_graph.G.remove_node(node_data["concept"])
node_update_end = time.time()
logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
logger.info(
f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)"
)
# 异步删除不存在的节点
node_delete_start = time.time()
if nodes_to_delete:
await self._async_delete_nodes(nodes_to_delete)
node_delete_end = time.time()
logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
logger.info(
f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)"
)
# 处理边的信息
edge_load_start = time.time()
@@ -1106,24 +1127,28 @@ class EntorhinalCortex:
if edge_key not in db_edge_dict:
# 添加新边到创建列表
edges_to_create.append({
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash,
'created_time': created_time,
'last_modified': last_modified
})
edges_to_create.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"created_time": created_time,
"last_modified": last_modified,
}
)
else:
# 检查边的特征值是否变化
if db_edge_dict[edge_key]["hash"] != edge_hash:
edges_to_update.append({
'source': source,
'target': target,
'strength': strength,
'hash': edge_hash,
'last_modified': last_modified
})
edges_to_update.append(
{
"source": source,
"target": target,
"strength": strength,
"hash": edge_hash,
"last_modified": last_modified,
}
)
edge_process_end = time.time()
logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}")
@@ -1132,20 +1157,24 @@ class EntorhinalCortex:
if edges_to_create:
batch_size = 100
for i in range(0, len(edges_to_create), batch_size):
batch = edges_to_create[i:i + batch_size]
batch = edges_to_create[i : i + batch_size]
await self._async_batch_create_edges(batch)
edge_create_end = time.time()
logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
logger.info(
f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)"
)
# 异步批量更新边
edge_update_start = time.time()
if edges_to_update:
batch_size = 100
for i in range(0, len(edges_to_update), batch_size):
batch = edges_to_update[i:i + batch_size]
batch = edges_to_update[i : i + batch_size]
await self._async_batch_update_edges(batch)
edge_update_end = time.time()
logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
logger.info(
f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)"
)
# 检查需要删除的边
memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
@@ -1157,7 +1186,9 @@ class EntorhinalCortex:
if edges_to_delete:
await self._async_delete_edges(edges_to_delete)
edge_delete_end = time.time()
logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
logger.info(
f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)"
)
end_time = time.time()
logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}")
@@ -1183,8 +1214,8 @@ class EntorhinalCortex:
"""异步批量更新节点"""
try:
for node_data in nodes_data:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
GraphNodes.concept == node_data['concept']
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
).execute()
except Exception as e:
logger.error(f"[同步] 批量更新节点失败: {e}")
@@ -1193,8 +1224,8 @@ class EntorhinalCortex:
async def _async_update_node(self, node_data):
"""异步更新单个节点"""
try:
GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
GraphNodes.concept == node_data['concept']
GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
GraphNodes.concept == node_data["concept"]
).execute()
except Exception as e:
logger.error(f"[同步] 更新节点失败: {e}")
@@ -1220,9 +1251,8 @@ class EntorhinalCortex:
"""异步批量更新边"""
try:
for edge_data in edges_data:
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
(GraphEdges.source == edge_data['source']) &
(GraphEdges.target == edge_data['target'])
GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
(GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
).execute()
except Exception as e:
logger.error(f"[同步] 批量更新边失败: {e}")
@@ -1232,10 +1262,7 @@ class EntorhinalCortex:
"""异步删除边"""
try:
for source, target in edge_keys:
GraphEdges.delete().where(
(GraphEdges.source == source) &
(GraphEdges.target == target)
).execute()
GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
except Exception as e:
logger.error(f"[同步] 删除边失败: {e}")
raise
@@ -1534,7 +1561,6 @@ class ParahippocampalGyrus:
all_added_edges.append(f"{topic1}-{topic2}")
self.memory_graph.connect_dot(topic1, topic2)
progress = (i / len(memory_samples)) * 100
bar_length = 30
filled_length = int(bar_length * i // len(memory_samples))
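
The sync code above validates each node's memory_items as a JSON string, then writes nodes and edges in batches of 100, falling back to per-item writes when a batch fails. A compact sketch of that validate-then-batch pattern, where write_batch is a hypothetical stand-in for the Peewee bulk operations:

```python
# Sketch of the chunked "validate, then batch-write" flow used for graph nodes.
import json

def chunked(items, size=100):
    for i in range(0, len(items), size):
        yield items[i : i + size]

def sync_nodes(nodes_to_create, write_batch, batch_size=100):
    valid = []
    for node in nodes_to_create:
        if not node.get("memory_items"):
            continue  # the real code logs and skips empty payloads
        try:
            json.loads(node["memory_items"])  # payload must be a valid JSON string
        except json.JSONDecodeError:
            continue  # the real code logs and skips invalid JSON
        valid.append(node)
    for batch in chunked(valid, batch_size):
        write_batch(batch)
    return len(valid)

created = sync_nodes(
    [{"concept": "rain", "memory_items": json.dumps(["brought an umbrella"])}],
    write_batch=lambda batch: None,  # stub writer for the example
)
print(created)  # 1
```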

View File

@@ -243,9 +243,7 @@ class NormalChat:
self.interest_dict.pop(msg_id, None)
# 改为实例方法, 移除 chat 参数
async def normal_response(
self, message: MessageRecv, is_mentioned: bool, interested_rate: float
) -> None:
async def normal_response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
# 新增:如果已停用,直接返回
if self._disabled:
logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。")
@@ -287,7 +285,6 @@ class NormalChat:
# 回复前处理
await willing_manager.before_generate_reply_handle(message.message_info.message_id)
thinking_id = await self._create_thinking_message(message)
logger.debug(f"[{self.stream_name}] 创建捕捉器thinking_id:{thinking_id}")
@@ -362,6 +359,9 @@ class NormalChat:
if action_type == "no_action":
logger.debug(f"[{self.stream_name}] Planner决定不执行任何额外动作")
return None
elif action_type == "change_to_focus_chat":
logger.info(f"[{self.stream_name}] Planner决定切换到focus聊天模式")
return None
# 执行额外的动作(不影响回复生成)
action_result = await self._execute_action(action_type, action_data, message, thinking_id)
@@ -396,7 +396,9 @@ class NormalChat:
elif plan_result:
logger.debug(f"[{self.stream_name}] 额外动作处理完成: {plan_result['action_type']}")
if not response_set or (self.enable_planner and self.action_type != "no_action"):
if not response_set or (
self.enable_planner and self.action_type not in ["no_action", "change_to_focus_chat"]
):
logger.info(f"[{self.stream_name}] 模型未生成回复内容")
# 如果模型未生成回复,移除思考消息
container = await message_manager.get_container(self.stream_id) # 使用 self.stream_id
@@ -445,6 +447,14 @@ class NormalChat:
# 检查是否需要切换到focus模式
if global_config.chat.chat_mode == "auto":
if self.action_type == "change_to_focus_chat":
logger.info(f"[{self.stream_name}] 检测到切换到focus聊天模式的请求")
if self.on_switch_to_focus_callback:
await self.on_switch_to_focus_callback()
else:
logger.warning(f"[{self.stream_name}] 没有设置切换到focus聊天模式的回调函数无法执行切换")
return
else:
await self._check_switch_to_focus()
info_catcher.done_catch()
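
The change above is the heart of this PR: in auto chat mode, when the planner picks change_to_focus_chat, NormalChat awaits a previously registered callback instead of executing an extra action. A hedged sketch of that wiring, using simplified hypothetical stand-ins for the chat object and callback:

```python
# Simplified illustration of the switch-to-focus callback flow added above.
import asyncio

class MiniNormalChat:
    def __init__(self, stream_name, on_switch_to_focus_callback=None):
        self.stream_name = stream_name
        self.on_switch_to_focus_callback = on_switch_to_focus_callback
        self.action_type = "no_action"

    async def after_reply(self, chat_mode):
        # Mirrors the new branch: only auto mode hands control over to focus chat.
        if chat_mode == "auto" and self.action_type == "change_to_focus_chat":
            if self.on_switch_to_focus_callback:
                await self.on_switch_to_focus_callback()
            else:
                print(f"[{self.stream_name}] no focus-chat callback registered, cannot switch")

async def main():
    async def switch_to_focus():
        print("switched to focus chat")

    chat = MiniNormalChat("demo", on_switch_to_focus_callback=switch_to_focus)
    chat.action_type = "change_to_focus_chat"
    await chat.after_reply("auto")

asyncio.run(main())
```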

View File

@@ -28,6 +28,7 @@ def init_prompt():
重要说明:
- "no_action" 表示只进行普通聊天回复,不执行任何额外动作
- "change_to_focus_chat" 表示当聊天变得热烈、自己回复条数很多或需要深入交流时正常回复消息并切换到focus_chat模式进行更深入的对话
- 其他action表示在普通回复的基础上执行相应的额外动作
你必须从上面列出的可用action中选择一个并说明原因。
@@ -156,8 +157,8 @@ class NormalChatPlanner:
# 提取其他参数作为action_data
action_data = {k: v for k, v in action_result.items() if k not in ["action", "reasoning"]}
# 验证动作是否在可用动作列表中
if action not in current_available_actions:
# 验证动作是否在可用动作列表中,或者是特殊动作
if action not in current_available_actions and action != "change_to_focus_chat":
logger.warning(f"{self.log_prefix}规划器选择了不可用的动作: {action}, 回退到no_action")
action = "no_action"
reasoning = f"选择的动作{action}不在可用列表中回退到no_action"
@@ -211,6 +212,19 @@ class NormalChatPlanner:
try:
# 构建动作选项文本
action_options_text = ""
# 添加特殊的change_to_focus_chat动作
action_options_text += "action_name: change_to_focus_chat\n"
action_options_text += (
" 描述当聊天变得热烈、自己回复条数很多或需要深入交流时使用正常回复消息并切换到focus_chat模式\n"
)
action_options_text += " 参数:\n"
action_options_text += " 动作要求:\n"
action_options_text += " - 聊天上下文中自己的回复条数较多超过3-4条\n"
action_options_text += " - 对话进行得非常热烈活跃\n"
action_options_text += " - 用户表现出深入交流的意图\n"
action_options_text += " - 话题需要更专注和深入的讨论\n\n"
for action_name, action_info in current_available_actions.items():
action_description = action_info.get("description", "")
action_parameters = action_info.get("parameters", {})
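
The planner change above prepends change_to_focus_chat to the prompt's action options and accepts it even when it is not in the per-chat available-action set. A small sketch of that validation fallback; the function itself is illustrative, not the planner's API:

```python
# Illustrative validation: keep the chosen action if it is available or the special
# focus-chat switch; otherwise fall back to "no_action", as the planner now does.
SPECIAL_ACTIONS = {"change_to_focus_chat"}

def validate_action(action, current_available_actions):
    if action in current_available_actions or action in SPECIAL_ACTIONS:
        return action, None
    return "no_action", f"action {action!r} is not available, falling back to no_action"

print(validate_action("change_to_focus_chat", {"reply": {}}))  # kept even if not listed
print(validate_action("dance", {"reply": {}}))                 # falls back to no_action
```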

View File

@@ -142,7 +142,9 @@ class TelemetryHeartBeatTask(AsyncTask):
del local_storage["mmc_uuid"] # 删除本地存储的UUID
else:
# 其他错误
logger.warning(f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}")
logger.warning(
f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}"
)
async def run(self):
# 发送心跳

View File

@@ -104,7 +104,13 @@ class PersonalityExpression:
if count >= self.max_calculations:
logger.debug(f"对于风格 '{current_style_text}' 已达到最大计算次数 ({self.max_calculations})。跳过提取。")
# 即使跳过,也更新元数据以反映当前风格已被识别且计数已满
self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
self._write_meta_data(
{
"last_style_text": current_style_text,
"count": count,
"last_update_time": meta_data.get("last_update_time"),
}
)
return
# 构建prompt
@@ -119,7 +125,13 @@ class PersonalityExpression:
except Exception as e:
logger.error(f"个性表达方式提取失败: {e}")
# 如果提取失败,保存当前的风格和未增加的计数
self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
self._write_meta_data(
{
"last_style_text": current_style_text,
"count": count,
"last_update_time": meta_data.get("last_update_time"),
}
)
return
logger.info(f"个性表达方式提取response: {response}")
@@ -147,8 +159,10 @@ class PersonalityExpression:
for new_expr in new_expressions:
found = False
for existing_expr in merged_expressions:
if (existing_expr["situation"] == new_expr["situation"] and
existing_expr["style"] == new_expr["style"]):
if (
existing_expr["situation"] == new_expr["situation"]
and existing_expr["style"] == new_expr["style"]
):
existing_expr["count"] += new_expr["count"]
found = True
break
@@ -168,11 +182,9 @@ class PersonalityExpression:
# 成功提取后更新元数据
count += 1
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._write_meta_data({
"last_style_text": current_style_text,
"count": count,
"last_update_time": current_time
})
self._write_meta_data(
{"last_style_text": current_style_text, "count": count, "last_update_time": current_time}
)
logger.info(f"成功处理。风格 '{current_style_text}' 的计数现在是 {count},最后更新时间:{current_time}")
else:
logger.warning(f"个性表达方式提取失败,模型返回空内容: {response}")