feat:新增表达方式可视化脚本
This commit is contained in:
136
scripts/analyze_group_similarity.py
Normal file
136
scripts/analyze_group_similarity.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
import networkx as nx
|
||||||
|
import matplotlib as mpl
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
# Configure matplotlib fonts so Chinese text in titles and labels renders.
plt.rcParams.update({
    'font.sans-serif': ['Microsoft YaHei'],  # use the Microsoft YaHei font
    'axes.unicode_minus': False,  # lets the minus sign display correctly
    'font.family': 'sans-serif',
})

# Directory that contains this script; the generated images are saved here.
SCRIPT_DIR = Path(__file__).parent
|
||||||
|
|
||||||
|
def get_group_name(stream_id):
    """Resolve a human-readable display name for a chat stream.

    Looks up ``stream_id`` in the ``chat_streams`` table of the SQLite
    database at ``data/maibot.db`` (relative to the current working
    directory) and returns the first non-empty value among:

    1. the stored group name,
    2. the stored user nickname,
    3. ``"<platform>-<first 8 characters of stream_id>"``.

    When no row matches, or every column is empty, ``stream_id`` itself is
    returned unchanged.

    Args:
        stream_id: Identifier of the chat stream to look up.

    Returns:
        The display name selected by the rules above.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect('data/maibot.db')
    try:
        row = conn.execute(
            '''
            SELECT group_name, user_nickname, platform
            FROM chat_streams
            WHERE stream_id = ?
            ''',
            (stream_id,),
        ).fetchone()
    finally:
        # Always release the connection, even when the query raises
        # (for example because the table does not exist).
        conn.close()

    if row is None:
        return stream_id

    group_name, user_nickname, platform = row
    if group_name:
        return group_name
    if user_nickname:
        return user_nickname
    if platform:
        # Only a prefix of the id is kept so the label stays short.
        return f"{platform}-{stream_id[:8]}"
    return stream_id
|
||||||
|
|
||||||
|
def load_group_expressions(group_dir):
    """Load one group's learnt expressions as a single text document.

    Reads ``expressions.json`` inside ``group_dir``. The file is expected to
    hold a list of objects with a ``style`` string and a ``count`` integer.
    Each style is repeated ``count`` times and all of them are joined with
    single spaces, so that more frequent styles carry more weight in a
    bag-of-words representation of the document.

    Args:
        group_dir: Directory of a single group (``str`` or ``Path``).

    Returns:
        str: The space-joined document, or an empty string when the group
        has no ``expressions.json``. A string is returned in every case so
        the result can be passed directly to a text vectorizer.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an entry lacks the ``style`` or ``count`` key.
    """
    json_path = Path(group_dir) / "expressions.json"
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            entries = json.load(f)
    except FileNotFoundError:
        # No learnt data for this group: return an empty document rather
        # than a value of a different type.
        return ''

    return ' '.join(
        style
        for entry in entries
        for style in [entry['style']] * entry['count']
    )
|
||||||
|
|
||||||
|
def analyze_group_similarity():
    """Compare the learnt expression styles of all groups and plot the result.

    Steps:

    1. Treat every sub-directory of ``data/expression/learnt_style``
       (relative to the current working directory) as one group and load
       its expression text with ``load_group_expressions``. Groups without
       any expression text are skipped.
    2. Vectorize the texts with TF-IDF and compute the pairwise cosine
       similarity, then apply ``log1p`` to the similarity matrix.
    3. Save a heatmap of the transformed matrix to
       ``group_similarity_heatmap.png`` next to this script.
    4. Save a network graph to ``group_similarity_network.png`` next to this
       script; two groups are connected when their transformed similarity
       exceeds ``log1p(0.05)`` and the edge width scales with it.

    If the base directory is missing or no group has expression text, a
    message is printed and nothing is plotted.

    Returns:
        None
    """
    base_dir = Path("data/expression/learnt_style")
    if not base_dir.is_dir():
        print(f"Expression directory not found: {base_dir}")
        return

    # Collect ids and texts together so both lists stay aligned after
    # groups without usable data are dropped.
    group_ids = []
    group_texts = []
    for group_dir in base_dir.iterdir():
        if not group_dir.is_dir():
            continue
        text = load_group_expressions(group_dir)
        if not text:
            # The vectorizer needs a non-empty string per document.
            continue
        group_ids.append(group_dir.name)
        group_texts.append(text)

    if not group_ids:
        print(f"No group expressions found under: {base_dir}")
        return

    # Human-readable labels for the plots.
    group_names = [get_group_name(group_id) for group_id in group_ids]

    # Vectorize the texts with TF-IDF and compare every pair of groups.
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(group_texts)
    similarity_matrix = cosine_similarity(tfidf_matrix)

    # Log-transform the similarity matrix before plotting.
    log_similarity_matrix = np.log1p(similarity_matrix)

    # --- Heatmap ---------------------------------------------------------
    plt.figure(figsize=(15, 12))
    sns.heatmap(
        log_similarity_matrix,
        xticklabels=group_names,
        yticklabels=group_names,
        cmap='YlOrRd',
        annot=True,
        fmt='.2f',
        vmin=0,
        vmax=np.log1p(0.2),  # upper colour limit, expressed on the log1p scale
    )
    plt.title('群组表达方式相似度热力图 (对数变换)')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(SCRIPT_DIR / 'group_similarity_heatmap.png', dpi=300, bbox_inches='tight')
    plt.close()

    # --- Network graph ---------------------------------------------------
    graph = nx.Graph()

    # Nodes are keyed by group id; the display name is kept as an attribute.
    for group_id, group_name in zip(group_ids, group_names):
        graph.add_node(group_id, label=group_name)

    # Connect each unordered pair once (j > i) when the transformed
    # similarity exceeds the threshold.
    edge_threshold = np.log1p(0.05)
    for i, source_id in enumerate(group_ids):
        for j in range(i + 1, len(group_ids)):
            weight = log_similarity_matrix[i][j]
            if weight > edge_threshold:
                graph.add_edge(source_id, group_ids[j], weight=weight)

    plt.figure(figsize=(20, 20))
    pos = nx.spring_layout(graph, k=1, iterations=50)

    # Draw the nodes.
    nx.draw_networkx_nodes(graph, pos, node_size=20000, node_color='lightblue', alpha=0.8)

    # Draw the edges; the factor 40 scales the weights into line widths.
    edge_widths = [graph[u][v]['weight'] * 40 for u, v in graph.edges()]
    nx.draw_networkx_edges(graph, pos, width=edge_widths, alpha=0.6, edge_color='gray')

    # Label each node with its display name.
    labels = {node: graph.nodes[node]['label'] for node in graph.nodes()}
    nx.draw_networkx_labels(graph, pos, labels, font_size=20, font_weight='bold')

    plt.title('群组表达方式相似度网络图\n(连线粗细表示对数变换后的相似度)')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(SCRIPT_DIR / 'group_similarity_network.png', dpi=300, bbox_inches='tight')
    plt.close()
|
||||||
|
|
||||||
|
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    analyze_group_similarity()
|
||||||
BIN
scripts/group_similarity_heatmap.png
Normal file
BIN
scripts/group_similarity_heatmap.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 657 KiB |
BIN
scripts/group_similarity_network.png
Normal file
BIN
scripts/group_similarity_network.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 884 KiB |
Reference in New Issue
Block a user