fix: ruff
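
The hunks below are the kind of changes ruff applies mechanically: unused imports and variables removed, single quotes normalized to double quotes, and long call sites rewrapped one argument per line. A minimal sketch of the likely invocation (an assumption; the exact command and configuration are not recorded in this commit):

    ruff check --fix .   # auto-fix lint findings such as unused imports
    ruff format .        # normalize quotes and rewrap long calls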
@@ -1,37 +1,38 @@
 import json
-import os
 from pathlib import Path
 import numpy as np
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import matplotlib.pyplot as plt
 import seaborn as sns
-import networkx as nx
-import matplotlib as mpl
 import sqlite3

 # 设置中文字体
-plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # 使用微软雅黑
-plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
-plt.rcParams['font.family'] = 'sans-serif'
+plt.rcParams["font.sans-serif"] = ["Microsoft YaHei"]  # 使用微软雅黑
+plt.rcParams["axes.unicode_minus"] = False  # 用来正常显示负号
+plt.rcParams["font.family"] = "sans-serif"

 # 获取脚本所在目录
 SCRIPT_DIR = Path(__file__).parent


 def get_group_name(stream_id):
     """从数据库中获取群组名称"""
-    conn = sqlite3.connect('data/maibot.db')
+    conn = sqlite3.connect("data/maibot.db")
     cursor = conn.cursor()

-    cursor.execute('''
+    cursor.execute(
+        """
     SELECT group_name, user_nickname, platform
     FROM chat_streams
     WHERE stream_id = ?
-    ''', (stream_id,))
+    """,
+        (stream_id,),
+    )

     result = cursor.fetchone()
     conn.close()

     if result:
         group_name, user_nickname, platform = result
         if group_name:
@@ -42,139 +43,148 @@ def get_group_name(stream_id):
             return f"{platform}-{stream_id[:8]}"
     return stream_id


 def load_group_data(group_dir):
     """加载单个群组的数据"""
     json_path = Path(group_dir) / "expressions.json"
     if not json_path.exists():
         return [], [], []

-    with open(json_path, 'r', encoding='utf-8') as f:
+    with open(json_path, "r", encoding="utf-8") as f:
         data = json.load(f)

     situations = []
     styles = []
     combined = []

     for item in data:
-        count = item['count']
-        situations.extend([item['situation']] * count)
-        styles.extend([item['style']] * count)
+        count = item["count"]
+        situations.extend([item["situation"]] * count)
+        styles.extend([item["style"]] * count)
         combined.extend([f"{item['situation']} {item['style']}"] * count)

     return situations, styles, combined


 def analyze_group_similarity():
     # 获取所有群组目录
     base_dir = Path("data/expression/learnt_style")
     group_dirs = [d for d in base_dir.iterdir() if d.is_dir()]
     group_ids = [d.name for d in group_dirs]

     # 获取群组名称
     group_names = [get_group_name(group_id) for group_id in group_ids]

     # 加载所有群组的数据
     group_situations = []
     group_styles = []
     group_combined = []

     for d in group_dirs:
         situations, styles, combined = load_group_data(d)
-        group_situations.append(' '.join(situations))
-        group_styles.append(' '.join(styles))
-        group_combined.append(' '.join(combined))
+        group_situations.append(" ".join(situations))
+        group_styles.append(" ".join(styles))
+        group_combined.append(" ".join(combined))

     # 创建TF-IDF向量化器
     vectorizer = TfidfVectorizer()

     # 计算三种相似度矩阵
     situation_matrix = cosine_similarity(vectorizer.fit_transform(group_situations))
     style_matrix = cosine_similarity(vectorizer.fit_transform(group_styles))
     combined_matrix = cosine_similarity(vectorizer.fit_transform(group_combined))

     # 对相似度矩阵进行对数变换
     log_situation_matrix = np.log1p(situation_matrix)
     log_style_matrix = np.log1p(style_matrix)
     log_combined_matrix = np.log1p(combined_matrix)

     # 创建一个大图,包含三个子图
     plt.figure(figsize=(45, 12))

     # 场景相似度热力图
     plt.subplot(1, 3, 1)
-    sns.heatmap(log_situation_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组场景相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
+    sns.heatmap(
+        log_situation_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组场景相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")

     # 表达方式相似度热力图
     plt.subplot(1, 3, 2)
-    sns.heatmap(log_style_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组表达方式相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
+    sns.heatmap(
+        log_style_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组表达方式相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")

     # 组合相似度热力图
     plt.subplot(1, 3, 3)
-    sns.heatmap(log_combined_matrix,
-                xticklabels=group_names,
-                yticklabels=group_names,
-                cmap='YlOrRd',
-                annot=True,
-                fmt='.2f',
-                vmin=0,
-                vmax=np.log1p(0.2))
-    plt.title('群组场景+表达方式相似度热力图 (对数变换)')
-    plt.xticks(rotation=45, ha='right')
+    sns.heatmap(
+        log_combined_matrix,
+        xticklabels=group_names,
+        yticklabels=group_names,
+        cmap="YlOrRd",
+        annot=True,
+        fmt=".2f",
+        vmin=0,
+        vmax=np.log1p(0.2),
+    )
+    plt.title("群组场景+表达方式相似度热力图 (对数变换)")
+    plt.xticks(rotation=45, ha="right")

     plt.tight_layout()
-    plt.savefig(SCRIPT_DIR / 'group_similarity_heatmaps.png', dpi=300, bbox_inches='tight')
+    plt.savefig(SCRIPT_DIR / "group_similarity_heatmaps.png", dpi=300, bbox_inches="tight")
     plt.close()

     # 保存匹配详情到文本文件
-    with open(SCRIPT_DIR / 'group_similarity_details.txt', 'w', encoding='utf-8') as f:
-        f.write('群组相似度详情\n')
-        f.write('=' * 50 + '\n\n')
+    with open(SCRIPT_DIR / "group_similarity_details.txt", "w", encoding="utf-8") as f:
+        f.write("群组相似度详情\n")
+        f.write("=" * 50 + "\n\n")

         for i in range(len(group_ids)):
-            for j in range(i+1, len(group_ids)):
+            for j in range(i + 1, len(group_ids)):
                 if log_combined_matrix[i][j] > np.log1p(0.05):
-                    f.write(f'群组1: {group_names[i]}\n')
-                    f.write(f'群组2: {group_names[j]}\n')
-                    f.write(f'场景相似度: {situation_matrix[i][j]:.4f}\n')
-                    f.write(f'表达方式相似度: {style_matrix[i][j]:.4f}\n')
-                    f.write(f'组合相似度: {combined_matrix[i][j]:.4f}\n')
+                    f.write(f"群组1: {group_names[i]}\n")
+                    f.write(f"群组2: {group_names[j]}\n")
+                    f.write(f"场景相似度: {situation_matrix[i][j]:.4f}\n")
+                    f.write(f"表达方式相似度: {style_matrix[i][j]:.4f}\n")
+                    f.write(f"组合相似度: {combined_matrix[i][j]:.4f}\n")

                     # 获取两个群组的数据
                     situations1, styles1, _ = load_group_data(group_dirs[i])
                     situations2, styles2, _ = load_group_data(group_dirs[j])

                     # 找出共同的场景
                     common_situations = set(situations1) & set(situations2)
                     if common_situations:
-                        f.write('\n共同场景:\n')
+                        f.write("\n共同场景:\n")
                         for situation in common_situations:
-                            f.write(f'- {situation}\n')
+                            f.write(f"- {situation}\n")

                     # 找出共同的表达方式
                     common_styles = set(styles1) & set(styles2)
                     if common_styles:
-                        f.write('\n共同表达方式:\n')
+                        f.write("\n共同表达方式:\n")
                         for style in common_styles:
-                            f.write(f'- {style}\n')
+                            f.write(f"- {style}\n")

-                    f.write('\n' + '-' * 50 + '\n\n')
+                    f.write("\n" + "-" * 50 + "\n\n")


 if __name__ == "__main__":
     analyze_group_similarity()
@@ -32,7 +32,6 @@ from rich.panel import Panel
 from src.common.database.database import db
 from src.common.database.database_model import (
     ChatStreams,
-    LLMUsage,
     Emoji,
     Messages,
     Images,
@@ -8,162 +8,174 @@ import matplotlib.pyplot as plt
 from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
 from collections import defaultdict


 class ExpressionViewer:
     def __init__(self, root):
         self.root = root
         self.root.title("表达方式预览器")
         self.root.geometry("1200x800")

         # 创建主框架
         self.main_frame = ttk.Frame(root)
         self.main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

         # 创建左侧控制面板
         self.control_frame = ttk.Frame(self.main_frame)
         self.control_frame.pack(side=tk.LEFT, fill=tk.Y, padx=(0, 10))

         # 创建搜索框
         self.search_frame = ttk.Frame(self.control_frame)
         self.search_frame.pack(fill=tk.X, pady=(0, 10))

         self.search_var = tk.StringVar()
-        self.search_var.trace('w', self.filter_expressions)
+        self.search_var.trace("w", self.filter_expressions)
         self.search_entry = ttk.Entry(self.search_frame, textvariable=self.search_var)
         self.search_entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
         ttk.Label(self.search_frame, text="搜索:").pack(side=tk.LEFT, padx=(0, 5))

         # 创建文件选择下拉框
         self.file_var = tk.StringVar()
         self.file_combo = ttk.Combobox(self.search_frame, textvariable=self.file_var)
         self.file_combo.pack(side=tk.LEFT, padx=5)
-        self.file_combo.bind('<<ComboboxSelected>>', self.load_file)
+        self.file_combo.bind("<<ComboboxSelected>>", self.load_file)

         # 创建排序选项
         self.sort_frame = ttk.LabelFrame(self.control_frame, text="排序选项")
         self.sort_frame.pack(fill=tk.X, pady=5)

         self.sort_var = tk.StringVar(value="count")
-        ttk.Radiobutton(self.sort_frame, text="按计数排序", variable=self.sort_var,
-                        value="count", command=self.apply_sort).pack(anchor=tk.W)
-        ttk.Radiobutton(self.sort_frame, text="按情境排序", variable=self.sort_var,
-                        value="situation", command=self.apply_sort).pack(anchor=tk.W)
-        ttk.Radiobutton(self.sort_frame, text="按风格排序", variable=self.sort_var,
-                        value="style", command=self.apply_sort).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.sort_frame, text="按计数排序", variable=self.sort_var, value="count", command=self.apply_sort
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.sort_frame, text="按情境排序", variable=self.sort_var, value="situation", command=self.apply_sort
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.sort_frame, text="按风格排序", variable=self.sort_var, value="style", command=self.apply_sort
+        ).pack(anchor=tk.W)

         # 创建分群选项
         self.group_frame = ttk.LabelFrame(self.control_frame, text="分群选项")
         self.group_frame.pack(fill=tk.X, pady=5)

         self.group_var = tk.StringVar(value="none")
-        ttk.Radiobutton(self.group_frame, text="不分群", variable=self.group_var,
-                        value="none", command=self.apply_grouping).pack(anchor=tk.W)
-        ttk.Radiobutton(self.group_frame, text="按情境分群", variable=self.group_var,
-                        value="situation", command=self.apply_grouping).pack(anchor=tk.W)
-        ttk.Radiobutton(self.group_frame, text="按风格分群", variable=self.group_var,
-                        value="style", command=self.apply_grouping).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.group_frame, text="不分群", variable=self.group_var, value="none", command=self.apply_grouping
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.group_frame, text="按情境分群", variable=self.group_var, value="situation", command=self.apply_grouping
+        ).pack(anchor=tk.W)
+        ttk.Radiobutton(
+            self.group_frame, text="按风格分群", variable=self.group_var, value="style", command=self.apply_grouping
+        ).pack(anchor=tk.W)

         # 创建相似度阈值滑块
         self.similarity_frame = ttk.LabelFrame(self.control_frame, text="相似度设置")
         self.similarity_frame.pack(fill=tk.X, pady=5)

         self.similarity_var = tk.DoubleVar(value=0.5)
-        self.similarity_scale = ttk.Scale(self.similarity_frame, from_=0.0, to=1.0,
-                                          variable=self.similarity_var, orient=tk.HORIZONTAL,
-                                          command=self.update_similarity)
+        self.similarity_scale = ttk.Scale(
+            self.similarity_frame,
+            from_=0.0,
+            to=1.0,
+            variable=self.similarity_var,
+            orient=tk.HORIZONTAL,
+            command=self.update_similarity,
+        )
         self.similarity_scale.pack(fill=tk.X, padx=5, pady=5)
         ttk.Label(self.similarity_frame, text="相似度阈值: 0.5").pack()

         # 创建显示选项
         self.view_frame = ttk.LabelFrame(self.control_frame, text="显示选项")
         self.view_frame.pack(fill=tk.X, pady=5)

         self.show_graph_var = tk.BooleanVar(value=True)
-        ttk.Checkbutton(self.view_frame, text="显示关系图", variable=self.show_graph_var,
-                        command=self.toggle_graph).pack(anchor=tk.W)
+        ttk.Checkbutton(
+            self.view_frame, text="显示关系图", variable=self.show_graph_var, command=self.toggle_graph
+        ).pack(anchor=tk.W)

         # 创建右侧内容区域
         self.content_frame = ttk.Frame(self.main_frame)
         self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

         # 创建文本显示区域
         self.text_area = tk.Text(self.content_frame, wrap=tk.WORD)
         self.text_area.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

         # 添加滚动条
         scrollbar = ttk.Scrollbar(self.text_area, command=self.text_area.yview)
         scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
         self.text_area.config(yscrollcommand=scrollbar.set)

         # 创建图形显示区域
         self.graph_frame = ttk.Frame(self.content_frame)
         self.graph_frame.pack(side=tk.TOP, fill=tk.BOTH, expand=True)

         # 初始化数据
         self.current_data = []
         self.graph = nx.Graph()
         self.canvas = None

         # 加载文件列表
         self.load_file_list()

     def load_file_list(self):
         expression_dir = Path("data/expression")
         files = []
         for root, _, filenames in os.walk(expression_dir):
             for filename in filenames:
-                if filename.endswith('.json'):
+                if filename.endswith(".json"):
                     rel_path = os.path.relpath(os.path.join(root, filename), expression_dir)
                     files.append(rel_path)

-        self.file_combo['values'] = files
+        self.file_combo["values"] = files
         if files:
             self.file_combo.set(files[0])
             self.load_file(None)

     def load_file(self, event):
         selected_file = self.file_var.get()
         if not selected_file:
             return

         file_path = os.path.join("data/expression", selected_file)
         try:
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, "r", encoding="utf-8") as f:
                 self.current_data = json.load(f)

             self.apply_sort()
             self.update_similarity()

         except Exception as e:
             self.text_area.delete(1.0, tk.END)
             self.text_area.insert(tk.END, f"加载文件时出错: {str(e)}")

     def apply_sort(self):
         if not self.current_data:
             return

         sort_key = self.sort_var.get()
         reverse = sort_key == "count"

         self.current_data.sort(key=lambda x: x.get(sort_key, ""), reverse=reverse)
         self.apply_grouping()

     def apply_grouping(self):
         if not self.current_data:
             return

         group_key = self.group_var.get()
         if group_key == "none":
             self.display_data(self.current_data)
             return

         grouped_data = defaultdict(list)
         for item in self.current_data:
             key = item.get(group_key, "未分类")
             grouped_data[key].append(item)

         self.text_area.delete(1.0, tk.END)
         for group, items in grouped_data.items():
             self.text_area.insert(tk.END, f"\n=== {group} ===\n\n")
@@ -172,7 +184,7 @@ class ExpressionViewer:
                 self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
                 self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
                 self.text_area.insert(tk.END, "-" * 50 + "\n")

     def display_data(self, data):
         self.text_area.delete(1.0, tk.END)
         for item in data:
@@ -180,59 +192,58 @@ class ExpressionViewer:
             self.text_area.insert(tk.END, f"风格: {item.get('style', 'N/A')}\n")
             self.text_area.insert(tk.END, f"计数: {item.get('count', 'N/A')}\n")
             self.text_area.insert(tk.END, "-" * 50 + "\n")

     def update_similarity(self, *args):
         if not self.current_data:
             return

         threshold = self.similarity_var.get()
         self.similarity_frame.winfo_children()[-1].config(text=f"相似度阈值: {threshold:.2f}")

         # 计算相似度
         texts = [f"{item['situation']} {item['style']}" for item in self.current_data]
         vectorizer = TfidfVectorizer()
         tfidf_matrix = vectorizer.fit_transform(texts)
         similarity_matrix = cosine_similarity(tfidf_matrix)

         # 创建图
         self.graph.clear()
         for i, item in enumerate(self.current_data):
             self.graph.add_node(i, label=f"{item['situation']}\n{item['style']}")

         # 添加边
         for i in range(len(self.current_data)):
             for j in range(i + 1, len(self.current_data)):
                 if similarity_matrix[i, j] > threshold:
                     self.graph.add_edge(i, j, weight=similarity_matrix[i, j])

         if self.show_graph_var.get():
             self.draw_graph()

     def draw_graph(self):
         if self.canvas:
             self.canvas.get_tk_widget().destroy()

         fig = plt.figure(figsize=(8, 6))
         pos = nx.spring_layout(self.graph)

         # 绘制节点
-        nx.draw_networkx_nodes(self.graph, pos, node_color='lightblue',
-                               node_size=1000, alpha=0.6)
+        nx.draw_networkx_nodes(self.graph, pos, node_color="lightblue", node_size=1000, alpha=0.6)

         # 绘制边
         nx.draw_networkx_edges(self.graph, pos, alpha=0.4)

         # 添加标签
-        labels = nx.get_node_attributes(self.graph, 'label')
+        labels = nx.get_node_attributes(self.graph, "label")
         nx.draw_networkx_labels(self.graph, pos, labels, font_size=8)

         plt.title("表达方式关系图")
-        plt.axis('off')
+        plt.axis("off")

         self.canvas = FigureCanvasTkAgg(fig, master=self.graph_frame)
         self.canvas.draw()
         self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

     def toggle_graph(self):
         if self.show_graph_var.get():
             self.draw_graph()
@@ -240,26 +251,28 @@ class ExpressionViewer:
         if self.canvas:
             self.canvas.get_tk_widget().destroy()
             self.canvas = None

     def filter_expressions(self, *args):
         search_text = self.search_var.get().lower()
         if not search_text:
             self.apply_sort()
             return

         filtered_data = []
         for item in self.current_data:
-            situation = item.get('situation', '').lower()
-            style = item.get('style', '').lower()
+            situation = item.get("situation", "").lower()
+            style = item.get("style", "").lower()
             if search_text in situation or search_text in style:
                 filtered_data.append(item)

         self.display_data(filtered_data)


 def main():
     root = tk.Tk()
-    app = ExpressionViewer(root)
+    # app = ExpressionViewer(root)
     root.mainloop()


 if __name__ == "__main__":
     main()
@@ -602,9 +602,10 @@ class EmojiManager:
                     continue

                 # 检查是否需要处理表情包(数量超过最大值或不足)
-                if global_config.emoji.steal_emoji and ((self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace) or (
-                    self.emoji_num < self.emoji_num_max
-                )):
+                if global_config.emoji.steal_emoji and (
+                    (self.emoji_num > self.emoji_num_max and global_config.emoji.do_replace)
+                    or (self.emoji_num < self.emoji_num_max)
+                ):
                     try:
                         # 获取目录下所有图片文件
                         files_to_process = [
@@ -121,7 +121,7 @@ class WorkingMemoryProcessor(BaseProcessor):
         )

         print(f"prompt: {prompt}")

         # 调用LLM处理记忆
         content = ""
         try:
@@ -194,7 +194,7 @@ class PluginAction(BaseAction):

         # 获取锚定消息(如果有)
         observations = self._services.get("observations", [])

         # 查找 ChattingObservation 实例
         chatting_observation = None
         for obs in observations:
@@ -217,14 +217,14 @@ class ActionPlanner(BasePlanner):
                 action_data[key] = value

         action_data["identity"] = self_info

         extra_info_block = "\n".join(extra_info)
         extra_info_block += f"\n{structured_info}"
         if extra_info or structured_info:
             extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
         else:
             extra_info_block = ""

         action_data["extra_info_block"] = extra_info_block

         # 对于reply动作不需要额外处理,因为相关字段已经在上面的循环中添加到action_data
@@ -263,9 +263,6 @@ class ActionPlanner(BasePlanner):
         )

         action_result = {"action_type": action, "action_data": action_data, "reasoning": reasoning}
-
-
-

         plan_result = {
             "action_result": action_result,
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional, Set, Tuple
+from typing import Dict, Any, Tuple
 import time
 import random
 import string
@@ -286,110 +286,110 @@ class MemoryManager:
             logger.error(f"生成总结时出错: {str(e)}")
             return default_summary

     # async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
     #     """
     #     对记忆进行精简操作,根据要求修改要点、总结和概括

     #     Args:
     #         memory_id: 记忆ID
     #         requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点

     #     Returns:
     #         修改后的记忆总结字典
     #     """
     #     # 获取指定ID的记忆项
     #     logger.info(f"精简记忆: {memory_id}")
     #     memory_item = self.get_by_id(memory_id)
     #     if not memory_item:
     #         raise ValueError(f"未找到ID为{memory_id}的记忆项")

     #     # 增加精简次数
     #     memory_item.increase_compress_count()

     #     summary = memory_item.summary

     #     # 使用LLM根据要求对总结、概括和要点进行精简修改
     #     prompt = f"""
     #     请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
     #     要求:{requirements}
     #     你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题

     #     目前主题:{summary["brief"]}

     #     目前关键要点:
     #     {chr(10).join([f"- {point}" for point in summary.get("points", [])])}

     #     请生成修改后的主题和关键要点,遵循以下格式:
     #     ```json
     #     {{
     #         "brief": "修改后的主题(20字以内)",
     #         "points": [
     #             "修改后的要点",
     #             "修改后的要点"
     #         ]
     #     }}
     #     ```
     #     请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
     #     """
     #     # 定义默认的精简结果
     #     default_refined = {
     #         "brief": summary["brief"],
     #         "points": summary.get("points", ["未知的要点"])[:1],  # 默认只保留第一个要点
     #     }

     #     try:
     #         # 调用LLM修改总结、概括和要点
     #         response, _ = await self.llm_summarizer.generate_response_async(prompt)
     #         logger.debug(f"精简记忆响应: {response}")
     #         # 使用repair_json处理响应
     #         try:
     #             # 修复JSON格式
     #             fixed_json_string = repair_json(response)

     #             # 将修复后的字符串解析为Python对象
     #             if isinstance(fixed_json_string, str):
     #                 try:
     #                     refined_data = json.loads(fixed_json_string)
     #                 except json.JSONDecodeError as decode_error:
     #                     logger.error(f"JSON解析错误: {str(decode_error)}")
     #                     refined_data = default_refined
     #             else:
     #                 # 如果repair_json直接返回了字典对象,直接使用
     #                 refined_data = fixed_json_string

     #             # 确保是字典类型
     #             if not isinstance(refined_data, dict):
     #                 logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
     #                 refined_data = default_refined

     #             # 更新总结
     #             summary["brief"] = refined_data.get("brief", "主题未知的记忆")

     #             # 更新关键要点
     #             points = refined_data.get("points", [])
     #             if isinstance(points, list) and points:
     #                 # 确保所有要点都是字符串
     #                 summary["points"] = [str(point) for point in points if point is not None]
     #             else:
     #                 # 如果points不是列表或为空,使用默认值
     #                 summary["points"] = ["主要要点已遗忘"]

     #         except Exception as e:
     #             logger.error(f"精简记忆出错: {str(e)}")
     #             traceback.print_exc()

     #             # 出错时使用简化的默认精简
     #             summary["brief"] = summary["brief"] + " (已简化)"
     #             summary["points"] = summary.get("points", ["未知的要点"])[:1]

     #     except Exception as e:
     #         logger.error(f"精简记忆调用LLM出错: {str(e)}")
     #         traceback.print_exc()

     #     # 更新原记忆项的总结
     #     memory_item.set_summary(summary)

     #     return memory_item

     def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
         """
@@ -1,6 +1,5 @@
 from typing import List, Any, Optional
 import asyncio
-import random
 from src.common.logger_manager import get_logger
 from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem

@@ -24,7 +24,6 @@ from rich.traceback import install

 from ...config.config import global_config
 from src.common.database.database_model import Messages, GraphNodes, GraphEdges  # Peewee Models导入
-from peewee import Case

 install(extra_lines=3)

@@ -217,7 +216,7 @@ class Hippocampus:
         """计算节点的特征值"""
         if not isinstance(memory_items, list):
             memory_items = [memory_items] if memory_items else []

         # 使用集合来去重,避免排序
         unique_items = set(str(item) for item in memory_items)
         # 使用frozenset来保证顺序一致性
@@ -816,7 +815,7 @@ class EntorhinalCortex:
         timestamps = sample_scheduler.get_timestamp_array()
         # 使用 translate_timestamp_to_human_readable 并指定 mode="normal"
         readable_timestamps = [translate_timestamp_to_human_readable(ts, mode="normal") for ts in timestamps]
-        for timestamp, readable_timestamp in zip(timestamps, readable_timestamps):
+        for _, readable_timestamp in zip(timestamps, readable_timestamps):
             logger.debug(f"回忆往事: {readable_timestamp}")
         chat_samples = []
         for timestamp in timestamps:
@@ -843,16 +842,20 @@ class EntorhinalCortex:
         # 定义时间范围:从目标时间戳开始,向后推移 time_window_seconds
         timestamp_start = target_timestamp
         timestamp_end = target_timestamp + time_window_seconds

         chosen_message = get_raw_msg_by_timestamp(
             timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=1, limit_mode="earliest"
         )

         if chosen_message:
             chat_id = chosen_message[0].get("chat_id")

             messages = get_raw_msg_by_timestamp_with_chat(
-                timestamp_start=timestamp_start, timestamp_end=timestamp_end, limit=chat_size, limit_mode="earliest", chat_id=chat_id
+                timestamp_start=timestamp_start,
+                timestamp_end=timestamp_end,
+                limit=chat_size,
+                limit_mode="earliest",
+                chat_id=chat_id,
             )

             if messages:
@@ -885,7 +888,7 @@ class EntorhinalCortex:
     async def sync_memory_to_db(self):
         """将记忆图同步到数据库"""
         start_time = time.time()

         # 获取数据库中所有节点和内存中所有节点
         db_load_start = time.time()
         db_nodes = {node.concept: node for node in GraphNodes.select()}
@@ -943,13 +946,15 @@ class EntorhinalCortex:

                 if concept not in db_nodes:
                     # 数据库中缺少的节点,添加到创建列表
-                    nodes_to_create.append({
-                        'concept': concept,
-                        'memory_items': memory_items_json,
-                        'hash': memory_hash,
-                        'created_time': created_time,
-                        'last_modified': last_modified
-                    })
+                    nodes_to_create.append(
+                        {
+                            "concept": concept,
+                            "memory_items": memory_items_json,
+                            "hash": memory_hash,
+                            "created_time": created_time,
+                            "last_modified": last_modified,
+                        }
+                    )
                     logger.debug(f"[同步] 准备创建节点: {concept}, memory_items长度: {len(memory_items)}")
                 else:
                     # 获取数据库中节点的特征值
@@ -958,12 +963,14 @@ class EntorhinalCortex:

                     # 如果特征值不同,则添加到更新列表
                     if db_hash != memory_hash:
-                        nodes_to_update.append({
-                            'concept': concept,
-                            'memory_items': memory_items_json,
-                            'hash': memory_hash,
-                            'last_modified': last_modified
-                        })
+                        nodes_to_update.append(
+                            {
+                                "concept": concept,
+                                "memory_items": memory_items_json,
+                                "hash": memory_hash,
+                                "last_modified": last_modified,
+                            }
+                        )

         # 检查需要删除的节点
         memory_concepts = {concept for concept, _ in memory_nodes}
@@ -972,7 +979,9 @@ class EntorhinalCortex:

         node_process_end = time.time()
         logger.info(f"[同步] 处理节点数据耗时: {node_process_end - node_process_start:.2f}秒")
-        logger.info(f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点")
+        logger.info(
+            f"[同步] 准备创建 {len(nodes_to_create)} 个节点,更新 {len(nodes_to_update)} 个节点,删除 {len(nodes_to_delete)} 个节点"
+        )

         # 异步批量创建新节点
         node_create_start = time.time()
@@ -981,22 +990,24 @@ class EntorhinalCortex:
         # 验证所有要创建的节点数据
         valid_nodes_to_create = []
         for node_data in nodes_to_create:
-            if not node_data.get('memory_items'):
+            if not node_data.get("memory_items"):
                 logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
                 continue
             try:
                 # 验证 JSON 字符串
-                json.loads(node_data['memory_items'])
+                json.loads(node_data["memory_items"])
                 valid_nodes_to_create.append(node_data)
             except json.JSONDecodeError:
-                logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                logger.warning(
+                    f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                )
                 continue

         if valid_nodes_to_create:
             # 使用异步批量插入
             batch_size = 100
             for i in range(0, len(valid_nodes_to_create), batch_size):
-                batch = valid_nodes_to_create[i:i + batch_size]
+                batch = valid_nodes_to_create[i : i + batch_size]
                 await self._async_batch_create_nodes(batch)
             logger.info(f"[同步] 成功创建 {len(valid_nodes_to_create)} 个节点")
         else:
@@ -1006,21 +1017,25 @@ class EntorhinalCortex:
             # 尝试逐个创建以找出问题节点
             for node_data in nodes_to_create:
                 try:
-                    if not node_data.get('memory_items'):
+                    if not node_data.get("memory_items"):
                         logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 为空")
                         continue
                     try:
-                        json.loads(node_data['memory_items'])
+                        json.loads(node_data["memory_items"])
                     except json.JSONDecodeError:
-                        logger.warning(f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                        logger.warning(
+                            f"[同步] 跳过创建节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                        )
                         continue
                     await self._async_create_node(node_data)
                 except Exception as e:
                     logger.error(f"[同步] 创建节点失败: {node_data['concept']}, 错误: {e}")
                     # 从图中移除问题节点
-                    self.memory_graph.G.remove_node(node_data['concept'])
+                    self.memory_graph.G.remove_node(node_data["concept"])
         node_create_end = time.time()
-        logger.info(f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)")
+        logger.info(
+            f"[同步] 创建新节点耗时: {node_create_end - node_create_start:.2f}秒 (创建了 {len(nodes_to_create)} 个节点)"
+        )

         # 异步批量更新节点
         node_update_start = time.time()
@@ -1028,30 +1043,32 @@ class EntorhinalCortex:
         # 按批次更新节点,每批100个
         batch_size = 100
         for i in range(0, len(nodes_to_update), batch_size):
-            batch = nodes_to_update[i:i + batch_size]
+            batch = nodes_to_update[i : i + batch_size]
             try:
                 # 验证批次中的每个节点数据
                 valid_batch = []
                 for node_data in batch:
                     # 确保 memory_items 不为空且是有效的 JSON 字符串
-                    if not node_data.get('memory_items'):
+                    if not node_data.get("memory_items"):
                         logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 为空")
                         continue
                     try:
                         # 验证 JSON 字符串是否有效
-                        json.loads(node_data['memory_items'])
+                        json.loads(node_data["memory_items"])
                         valid_batch.append(node_data)
                     except json.JSONDecodeError:
-                        logger.warning(f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串")
+                        logger.warning(
+                            f"[同步] 跳过更新节点 {node_data['concept']}: memory_items 不是有效的 JSON 字符串"
+                        )
                         continue

                 if not valid_batch:
-                    logger.warning(f"[同步] 批次 {i//batch_size + 1} 没有有效的节点可以更新")
+                    logger.warning(f"[同步] 批次 {i // batch_size + 1} 没有有效的节点可以更新")
                     continue

                 # 异步批量更新节点
                 await self._async_batch_update_nodes(valid_batch)
-                logger.debug(f"[同步] 成功更新批次 {i//batch_size + 1} 中的 {len(valid_batch)} 个节点")
+                logger.debug(f"[同步] 成功更新批次 {i // batch_size + 1} 中的 {len(valid_batch)} 个节点")
             except Exception as e:
                 logger.error(f"[同步] 批量更新节点失败: {e}")
                 # 如果批量更新失败,尝试逐个更新
@@ -1061,17 +1078,21 @@ class EntorhinalCortex:
                 except Exception as e:
                     logger.error(f"[同步] 更新节点失败: {node_data['concept']}, 错误: {e}")
                     # 从图中移除问题节点
-                    self.memory_graph.G.remove_node(node_data['concept'])
+                    self.memory_graph.G.remove_node(node_data["concept"])

         node_update_end = time.time()
-        logger.info(f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)")
+        logger.info(
+            f"[同步] 更新节点耗时: {node_update_end - node_update_start:.2f}秒 (更新了 {len(nodes_to_update)} 个节点)"
+        )

         # 异步删除不存在的节点
         node_delete_start = time.time()
         if nodes_to_delete:
             await self._async_delete_nodes(nodes_to_delete)
         node_delete_end = time.time()
-        logger.info(f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)")
+        logger.info(
+            f"[同步] 删除节点耗时: {node_delete_end - node_delete_start:.2f}秒 (删除了 {len(nodes_to_delete)} 个节点)"
+        )

         # 处理边的信息
         edge_load_start = time.time()
@@ -1106,24 +1127,28 @@ class EntorhinalCortex:

                 if edge_key not in db_edge_dict:
                     # 添加新边到创建列表
-                    edges_to_create.append({
-                        'source': source,
-                        'target': target,
-                        'strength': strength,
-                        'hash': edge_hash,
-                        'created_time': created_time,
-                        'last_modified': last_modified
-                    })
+                    edges_to_create.append(
+                        {
+                            "source": source,
+                            "target": target,
+                            "strength": strength,
+                            "hash": edge_hash,
+                            "created_time": created_time,
+                            "last_modified": last_modified,
+                        }
+                    )
                 else:
                     # 检查边的特征值是否变化
                     if db_edge_dict[edge_key]["hash"] != edge_hash:
-                        edges_to_update.append({
-                            'source': source,
-                            'target': target,
-                            'strength': strength,
-                            'hash': edge_hash,
-                            'last_modified': last_modified
-                        })
+                        edges_to_update.append(
+                            {
+                                "source": source,
+                                "target": target,
+                                "strength": strength,
+                                "hash": edge_hash,
+                                "last_modified": last_modified,
+                            }
+                        )
         edge_process_end = time.time()
         logger.info(f"[同步] 处理边数据耗时: {edge_process_end - edge_process_start:.2f}秒")

@@ -1132,20 +1157,24 @@ class EntorhinalCortex:
         if edges_to_create:
             batch_size = 100
             for i in range(0, len(edges_to_create), batch_size):
-                batch = edges_to_create[i:i + batch_size]
+                batch = edges_to_create[i : i + batch_size]
                 await self._async_batch_create_edges(batch)
         edge_create_end = time.time()
-        logger.info(f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)")
+        logger.info(
+            f"[同步] 创建新边耗时: {edge_create_end - edge_create_start:.2f}秒 (创建了 {len(edges_to_create)} 条边)"
+        )

         # 异步批量更新边
         edge_update_start = time.time()
         if edges_to_update:
             batch_size = 100
             for i in range(0, len(edges_to_update), batch_size):
-                batch = edges_to_update[i:i + batch_size]
+                batch = edges_to_update[i : i + batch_size]
                 await self._async_batch_update_edges(batch)
         edge_update_end = time.time()
-        logger.info(f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)")
+        logger.info(
+            f"[同步] 更新边耗时: {edge_update_end - edge_update_start:.2f}秒 (更新了 {len(edges_to_update)} 条边)"
+        )

         # 检查需要删除的边
         memory_edge_keys = {(source, target) for source, target, _ in memory_edges}
@@ -1157,7 +1186,9 @@ class EntorhinalCortex:
         if edges_to_delete:
             await self._async_delete_edges(edges_to_delete)
         edge_delete_end = time.time()
-        logger.info(f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)")
+        logger.info(
+            f"[同步] 删除边耗时: {edge_delete_end - edge_delete_start:.2f}秒 (删除了 {len(edges_to_delete)} 条边)"
+        )

         end_time = time.time()
         logger.success(f"[同步] 总耗时: {end_time - start_time:.2f}秒")
@@ -1183,8 +1214,8 @@ class EntorhinalCortex:
         """异步批量更新节点"""
         try:
             for node_data in nodes_data:
-                GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
-                    GraphNodes.concept == node_data['concept']
+                GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
+                    GraphNodes.concept == node_data["concept"]
                 ).execute()
         except Exception as e:
             logger.error(f"[同步] 批量更新节点失败: {e}")
@@ -1193,8 +1224,8 @@ class EntorhinalCortex:
     async def _async_update_node(self, node_data):
         """异步更新单个节点"""
         try:
-            GraphNodes.update(**{k: v for k, v in node_data.items() if k != 'concept'}).where(
-                GraphNodes.concept == node_data['concept']
+            GraphNodes.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
+                GraphNodes.concept == node_data["concept"]
             ).execute()
         except Exception as e:
             logger.error(f"[同步] 更新节点失败: {e}")
@@ -1220,9 +1251,8 @@ class EntorhinalCortex:
         """异步批量更新边"""
         try:
             for edge_data in edges_data:
-                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ['source', 'target']}).where(
-                    (GraphEdges.source == edge_data['source']) &
-                    (GraphEdges.target == edge_data['target'])
+                GraphEdges.update(**{k: v for k, v in edge_data.items() if k not in ["source", "target"]}).where(
+                    (GraphEdges.source == edge_data["source"]) & (GraphEdges.target == edge_data["target"])
                 ).execute()
         except Exception as e:
             logger.error(f"[同步] 批量更新边失败: {e}")
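
The update helpers all follow one peewee idiom: strip the lookup key(s) out of the payload dict, pass the rest to `Model.update(**fields)`, and pin the row with `.where(...)`. A self-contained sketch against an in-memory SQLite database, where the `Node` model and its `weight` field are stand-ins rather than the project's actual schema:

```python
from peewee import CharField, FloatField, Model, SqliteDatabase

db = SqliteDatabase(":memory:")


class Node(Model):
    concept = CharField(unique=True)
    weight = FloatField(default=0.0)

    class Meta:
        database = db


db.connect()
db.create_tables([Node])
Node.create(concept="memory", weight=1.0)

node_data = {"concept": "memory", "weight": 2.5}
# Exclude the lookup key from the SET clause, exactly as the code above
# strips "concept" (nodes) or "source"/"target" (edges) before updating.
Node.update(**{k: v for k, v in node_data.items() if k != "concept"}).where(
    Node.concept == node_data["concept"]
).execute()

print(Node.get(Node.concept == "memory").weight)  # 2.5
```
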
@@ -1232,10 +1262,7 @@ class EntorhinalCortex:
         """异步删除边"""
         try:
             for source, target in edge_keys:
-                GraphEdges.delete().where(
-                    (GraphEdges.source == source) &
-                    (GraphEdges.target == target)
-                ).execute()
+                GraphEdges.delete().where((GraphEdges.source == source) & (GraphEdges.target == target)).execute()
         except Exception as e:
             logger.error(f"[同步] 删除边失败: {e}")
             raise
@@ -1406,7 +1433,7 @@ class ParahippocampalGyrus:
         if not input_text:
             logger.warning("无法从提供的消息生成可读文本,跳过记忆压缩。")
             return set(), {}

         current_YMD_time = datetime.datetime.now().strftime("%Y-%m-%d")
         current_YMD_time_str = f"当前日期: {current_YMD_time}"
         input_text = f"{current_YMD_time_str}\n{input_text}"
@@ -1533,8 +1560,7 @@ class ParahippocampalGyrus:
                     logger.debug(f"连接同批次节点: {topic1} 和 {topic2}")
                     all_added_edges.append(f"{topic1}-{topic2}")
                     self.memory_graph.connect_dot(topic1, topic2)

-
             progress = (i / len(memory_samples)) * 100
             bar_length = 30
             filled_length = int(bar_length * i // len(memory_samples))
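
For context, the `progress` / `bar_length` / `filled_length` trio in this hunk is the standard text progress-bar computation. A runnable sketch with illustrative values (the real loop iterates over `memory_samples`):

```python
total = 120      # stand-in for len(memory_samples)
bar_length = 30
for i in (30, 60, 120):
    progress = (i / total) * 100
    filled_length = int(bar_length * i // total)
    # Filled portion grows proportionally; the rest is padding.
    bar = "#" * filled_length + "-" * (bar_length - filled_length)
    print(f"[{bar}] {progress:.1f}%")
```
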
@@ -79,7 +79,7 @@ class NormalChat:
         # 初始化Normal Chat专用表达器
         self.expressor = NormalChatExpressor(self.chat_stream, self.stream_name)
         self.replyer = DefaultReplyer(chat_id=self.stream_id)

         self.replyer.chat_stream = self.chat_stream

         self._initialized = True
@@ -243,9 +243,7 @@ class NormalChat:
             self.interest_dict.pop(msg_id, None)

     # 改为实例方法, 移除 chat 参数
-    async def normal_response(
-        self, message: MessageRecv, is_mentioned: bool, interested_rate: float
-    ) -> None:
+    async def normal_response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
         # 新增:如果已停用,直接返回
         if self._disabled:
             logger.info(f"[{self.stream_name}] 已停用,忽略 normal_response。")
@@ -287,7 +285,6 @@ class NormalChat:
         # 回复前处理
         await willing_manager.before_generate_reply_handle(message.message_info.message_id)

-
         thinking_id = await self._create_thinking_message(message)

         logger.debug(f"[{self.stream_name}] 创建捕捉器,thinking_id:{thinking_id}")
@@ -142,7 +142,9 @@ class TelemetryHeartBeatTask(AsyncTask):
                 del local_storage["mmc_uuid"]  # 删除本地存储的UUID
             else:
                 # 其他错误
-                logger.warning(f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}")
+                logger.warning(
+                    f"(此错误不会影响正常使用)状态未发送,状态码: {response.status_code}, 响应内容: {response.text}"
+                )

     async def run(self):
         # 发送心跳
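
The recurring pattern in these logging hunks is purely mechanical: when a call exceeds the configured line length, the formatter moves the single f-string argument onto its own line inside the parentheses, which changes nothing at runtime. A quick sketch, with hypothetical stand-ins for the response fields:

```python
status_code, text = 500, "error"  # hypothetical stand-ins for response fields

inline = f"status={status_code}, body={text}"
wrapped = (
    f"status={status_code}, body={text}"
)
# The parentheses only change layout; the resulting string is identical.
assert inline == wrapped
```
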
@@ -104,7 +104,13 @@ class PersonalityExpression:
         if count >= self.max_calculations:
             logger.debug(f"对于风格 '{current_style_text}' 已达到最大计算次数 ({self.max_calculations})。跳过提取。")
             # 即使跳过,也更新元数据以反映当前风格已被识别且计数已满
-            self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
+            self._write_meta_data(
+                {
+                    "last_style_text": current_style_text,
+                    "count": count,
+                    "last_update_time": meta_data.get("last_update_time"),
+                }
+            )
             return

         # 构建prompt
@@ -119,12 +125,18 @@ class PersonalityExpression:
         except Exception as e:
             logger.error(f"个性表达方式提取失败: {e}")
             # 如果提取失败,保存当前的风格和未增加的计数
-            self._write_meta_data({"last_style_text": current_style_text, "count": count, "last_update_time": meta_data.get("last_update_time")})
+            self._write_meta_data(
+                {
+                    "last_style_text": current_style_text,
+                    "count": count,
+                    "last_update_time": meta_data.get("last_update_time"),
+                }
+            )
             return

         logger.info(f"个性表达方式提取response: {response}")
         # chat_id用personality

         # 转为dict并count=100
         if response != "":
             expressions = self.parse_expression_response(response, "personality")
@@ -136,25 +148,27 @@ class PersonalityExpression:
                     existing_expressions = json.load(f)
             except (json.JSONDecodeError, FileNotFoundError):
                 logger.warning(f"无法读取或解析 {self.expressions_file_path},将创建新的表达文件。")

             # 创建新的表达方式
             new_expressions = []
             for _, situation, style in expressions:
                 new_expressions.append({"situation": situation, "style": style, "count": 1})

             # 合并表达方式,如果situation和style相同则累加count
             merged_expressions = existing_expressions.copy()
             for new_expr in new_expressions:
                 found = False
                 for existing_expr in merged_expressions:
-                    if (existing_expr["situation"] == new_expr["situation"] and
-                        existing_expr["style"] == new_expr["style"]):
+                    if (
+                        existing_expr["situation"] == new_expr["situation"]
+                        and existing_expr["style"] == new_expr["style"]
+                    ):
                         existing_expr["count"] += new_expr["count"]
                         found = True
                         break
                 if not found:
                     merged_expressions.append(new_expr)

             # 超过50条时随机删除多余的,只保留50条
             if len(merged_expressions) > 50:
                 remove_count = len(merged_expressions) - 50
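
The merge step reformatted above accumulates counts for identical (situation, style) pairs via a nested scan. A standalone sketch of the same logic using a dict keyed by the pair, which gives the same result in linear time (the sample expressions are illustrative):

```python
existing = [{"situation": "被调侃", "style": "用颜文字回应", "count": 3}]
new = [
    {"situation": "被调侃", "style": "用颜文字回应", "count": 1},
    {"situation": "深夜聊天", "style": "语气变得更软", "count": 1},
]

# Key each expression by its (situation, style) pair.
merged = {(e["situation"], e["style"]): dict(e) for e in existing}
for expr in new:
    key = (expr["situation"], expr["style"])
    if key in merged:
        merged[key]["count"] += expr["count"]  # same pair: accumulate count
    else:
        merged[key] = dict(expr)               # new pair: append

print(list(merged.values()))
# [{'situation': '被调侃', 'style': '用颜文字回应', 'count': 4},
#  {'situation': '深夜聊天', 'style': '语气变得更软', 'count': 1}]
```
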
@@ -168,11 +182,9 @@ class PersonalityExpression:
             # 成功提取后更新元数据
             count += 1
             current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-            self._write_meta_data({
-                "last_style_text": current_style_text,
-                "count": count,
-                "last_update_time": current_time
-            })
+            self._write_meta_data(
+                {"last_style_text": current_style_text, "count": count, "last_update_time": current_time}
+            )
             logger.info(f"成功处理。风格 '{current_style_text}' 的计数现在是 {count},最后更新时间:{current_time}。")
         else:
             logger.warning(f"个性表达方式提取失败,模型返回空内容: {response}")
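
Note the asymmetry between this hunk and the two `_write_meta_data` hunks above: the formatter explodes a call one element per line only when it cannot fit within the configured line length, and a trailing comma after the last element (the "magic trailing comma" in ruff format / black) then keeps it exploded. Here the dict fits, so it collapses onto a single line. A sketch of both outcomes, with placeholder values:

```python
# Fits the line length and has no trailing comma -> the formatter keeps
# (or joins it onto) one line, as in the final _write_meta_data call:
meta_short = {"count": 3, "last_update_time": "2024-01-01 00:00:00"}

# A trailing comma after the last element tells ruff format / black to
# keep one element per line, as in the earlier _write_meta_data calls:
meta_long = {
    "last_style_text": "...",
    "count": 3,
    "last_update_time": "2024-01-01 00:00:00",
}
print(meta_short, meta_long)
```
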