ruff
BIN  requirements.txt
Binary file not shown.
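The diffs below come from running the ruff formatter and linter over the repository (quote normalization, slice spacing, call and signature re-wrapping, blank-line enforcement, unused-import and bare-except cleanup), together with a small refactor of person_info usage. Assuming ruff is installed and configured for the project, changes of this kind are typically reproduced with `ruff format .` followed by `ruff check --fix .` from the repository root.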
@@ -9,13 +9,15 @@ import sqlite3
 import re
 from datetime import datetime
 
+
 def clean_group_name(name: str) -> str:
     """Clean a group name, keeping only Chinese and English characters."""
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned
 
+
 def get_group_name(stream_id: str) -> str:
     """Look up the group name from the database."""
     conn = sqlite3.connect("data/maibot.db")
@@ -43,6 +45,7 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id
 
+
 def format_timestamp(timestamp: float) -> str:
     """Convert a timestamp into a human-readable time string."""
     if not timestamp:
@@ -50,132 +53,140 @@ def format_timestamp(timestamp: float) -> str:
     try:
         dt = datetime.fromtimestamp(timestamp)
         return dt.strftime("%Y-%m-%d %H:%M:%S")
-    except:
+    except Exception as e:
+        print(f"Timestamp formatting error: {e}")
         return "Unknown"
 
+
 def load_expressions(chat_id: str) -> List[Dict]:
     """Load the expressions learned for a given group chat."""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
 
     style_exprs = []
 
     if os.path.exists(style_file):
         with open(style_file, "r", encoding="utf-8") as f:
             style_exprs = json.load(f)
 
     return style_exprs
 
+
 def find_similar_expressions(expressions: List[Dict], top_k: int = 5) -> Dict[str, List[Tuple[str, float]]]:
     """Find the top_k most similar expressions for each expression."""
     if not expressions:
         return {}
 
     # Prepare the situation and style text corpora separately
-    situations = [expr['situation'] for expr in expressions]
-    styles = [expr['style'] for expr in expressions]
+    situations = [expr["situation"] for expr in expressions]
+    styles = [expr["style"] for expr in expressions]
 
     # Vectorize with TF-IDF
     vectorizer = TfidfVectorizer()
     situation_matrix = vectorizer.fit_transform(situations)
     style_matrix = vectorizer.fit_transform(styles)
 
     # Compute cosine similarity
     situation_similarity = cosine_similarity(situation_matrix)
     style_similarity = cosine_similarity(style_matrix)
 
     # For each expression, find its top_k most similar ones
     similar_expressions = {}
-    for i, expr in enumerate(expressions):
+    for i, _ in enumerate(expressions):
         # Get the similarity scores
         situation_scores = situation_similarity[i]
         style_scores = style_similarity[i]
 
         # Get the top_k indices (excluding the expression itself)
-        situation_indices = np.argsort(situation_scores)[::-1][1:top_k+1]
-        style_indices = np.argsort(style_scores)[::-1][1:top_k+1]
+        situation_indices = np.argsort(situation_scores)[::-1][1 : top_k + 1]
+        style_indices = np.argsort(style_scores)[::-1][1 : top_k + 1]
 
         similar_situations = []
         similar_styles = []
 
         # Collect similar situations
         for idx in situation_indices:
             if situation_scores[idx] > 0:  # keep only entries with nonzero similarity
-                similar_situations.append((
-                    expressions[idx]['situation'],
-                    expressions[idx]['style'],  # the corresponding original style
-                    situation_scores[idx]
-                ))
+                similar_situations.append(
+                    (
+                        expressions[idx]["situation"],
+                        expressions[idx]["style"],  # the corresponding original style
+                        situation_scores[idx],
+                    )
+                )
 
         # Collect similar styles
         for idx in style_indices:
             if style_scores[idx] > 0:  # keep only entries with nonzero similarity
-                similar_styles.append((
-                    expressions[idx]['style'],
-                    expressions[idx]['situation'],  # the corresponding original situation
-                    style_scores[idx]
-                ))
+                similar_styles.append(
+                    (
+                        expressions[idx]["style"],
+                        expressions[idx]["situation"],  # the corresponding original situation
+                        style_scores[idx],
+                    )
+                )
 
         if similar_situations or similar_styles:
-            similar_expressions[i] = {
-                'situations': similar_situations,
-                'styles': similar_styles
-            }
+            similar_expressions[i] = {"situations": similar_situations, "styles": similar_styles}
 
     return similar_expressions
 
+
 def main():
     # Collect all group chat IDs
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
     chat_ids = [os.path.basename(d) for d in style_dirs]
 
     if not chat_ids:
         print("No expression data found for any group chat")
         return
 
     print("Available group chats:")
     for i, chat_id in enumerate(chat_ids, 1):
         group_name = get_group_name(chat_id)
         print(f"{i}. {group_name}")
 
     while True:
         try:
             choice = int(input("\nSelect the number of the group chat to analyze (0 to exit): "))
             if choice == 0:
                 break
             if 1 <= choice <= len(chat_ids):
-                chat_id = chat_ids[choice-1]
+                chat_id = chat_ids[choice - 1]
                 break
             print("Invalid choice, please try again")
         except ValueError:
             print("Please enter a valid number")
 
     if choice == 0:
         return
 
     # Load the expressions
     style_exprs = load_expressions(chat_id)
 
     group_name = get_group_name(chat_id)
     print(f"\nAnalyzing expressions for group chat {group_name}:")
 
     similar_styles = find_similar_expressions(style_exprs)
     for i, expr in enumerate(style_exprs):
         if i in similar_styles:
             print("\n" + "-" * 20)
             print(f"Style: {expr['style']} <---> Situation: {expr['situation']}")
 
-            if similar_styles[i]['styles']:
+            if similar_styles[i]["styles"]:
                 print("\n\033[33mSimilar styles:\033[0m")
-                for similar_style, original_situation, score in similar_styles[i]['styles']:
+                for similar_style, original_situation, score in similar_styles[i]["styles"]:
                     print(f"\033[33m{similar_style}, score: {score:.3f}, matching situation: {original_situation}\033[0m")
 
-            if similar_styles[i]['situations']:
+            if similar_styles[i]["situations"]:
                 print("\n\033[32mSimilar situations:\033[0m")
-                for similar_situation, original_style, score in similar_styles[i]['situations']:
+                for similar_situation, original_style, score in similar_styles[i]["situations"]:
                     print(f"\033[32m{similar_situation}, score: {score:.3f}, matching style: {original_style}\033[0m")
 
-            print(f"\nActivation: {expr.get('count', 1):.3f}, last active at: {format_timestamp(expr.get('last_active_time'))}")
+            print(
+                f"\nActivation: {expr.get('count', 1):.3f}, last active at: {format_timestamp(expr.get('last_active_time'))}"
+            )
             print("-" * 20)
 
+
 if __name__ == "__main__":
     main()
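The script above ranks expressions by TF-IDF cosine similarity. A minimal standalone sketch of that pattern, with illustrative sample strings rather than real data:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    texts = ["greeting a close friend", "greeting a stranger politely", "saying goodbye"]
    matrix = TfidfVectorizer().fit_transform(texts)  # one TF-IDF row per text
    sim = cosine_similarity(matrix)  # sim[i][j]: similarity between texts i and j
    print(sim[0])  # first row; sim[0][0] is 1.0, a text compared with itself

A text's similarity with itself is always 1, which is why the script above skips the first entry of each descending sort with the slice [1 : top_k + 1].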
@@ -6,15 +6,17 @@ from datetime import datetime
 from typing import Dict, List, Any
 import sqlite3
 
+
 def clean_group_name(name: str) -> str:
     """Clean a group name, keeping only Chinese and English characters."""
     # Keep Chinese and English characters
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     # Fall back to the current date if nothing remains after cleaning
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned
 
+
 def get_group_name(stream_id: str) -> str:
     """Look up the group name from the database."""
     conn = sqlite3.connect("data/maibot.db")
@@ -42,41 +44,44 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id
 
+
 def load_expressions(chat_id: str) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[Dict[str, Any]]]:
     """Load the expressions for a given group."""
     learnt_style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
     learnt_grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
     personality_file = os.path.join("data", "expression", "personality", "expressions.json")
 
     style_expressions = []
     grammar_expressions = []
     personality_expressions = []
 
     if os.path.exists(learnt_style_file):
         with open(learnt_style_file, "r", encoding="utf-8") as f:
             style_expressions = json.load(f)
 
     if os.path.exists(learnt_grammar_file):
         with open(learnt_grammar_file, "r", encoding="utf-8") as f:
             grammar_expressions = json.load(f)
 
     if os.path.exists(personality_file):
         with open(personality_file, "r", encoding="utf-8") as f:
             personality_expressions = json.load(f)
 
     return style_expressions, grammar_expressions, personality_expressions
 
+
 def format_time(timestamp: float) -> str:
     """Format a timestamp as a human-readable string."""
     return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
 
+
 def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
     """Write out a list of expressions."""
     if not expressions:
         f.write(f"{title}: no data\n")
         f.write("-" * 40 + "\n")
         return
 
     f.write(f"{title}:\n")
     for expr in expressions:
         count = expr.get("count", 0)
@@ -87,103 +92,111 @@ def write_expressions(f, expressions: List[Dict[str, Any]], title: str):
         f.write(f"Last active: {format_time(last_active)}\n")
         f.write("-" * 40 + "\n")
 
+
-def write_group_report(group_file: str, group_name: str, chat_id: str, style_exprs: List[Dict[str, Any]], grammar_exprs: List[Dict[str, Any]]):
+def write_group_report(
+    group_file: str,
+    group_name: str,
+    chat_id: str,
+    style_exprs: List[Dict[str, Any]],
+    grammar_exprs: List[Dict[str, Any]],
+):
     """Write the detailed report for one group."""
     with open(group_file, "w", encoding="utf-8") as gf:
         gf.write(f"Group: {group_name} (ID: {chat_id})\n")
         gf.write("=" * 80 + "\n\n")
 
         # Write the language styles
         gf.write("[Language styles]\n")
         gf.write("=" * 40 + "\n")
         write_expressions(gf, style_exprs, "Language styles")
         gf.write("\n")
 
         # Write the grammar traits
         gf.write("[Grammar traits]\n")
         gf.write("=" * 40 + "\n")
         write_expressions(gf, grammar_exprs, "Grammar traits")
 
+
 def analyze_expressions():
     """Analyze the expressions of every group."""
     # Collect all group IDs
     style_dir = os.path.join("data", "expression", "learnt_style")
     chat_ids = [d for d in os.listdir(style_dir) if os.path.isdir(os.path.join(style_dir, d))]
 
     # Create the output directories
     output_dir = "data/expression_analysis"
     personality_dir = os.path.join(output_dir, "personality")
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(personality_dir, exist_ok=True)
 
     # Generate a timestamp
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
 
     # Create the summary report
     summary_file = os.path.join(output_dir, f"summary_{timestamp}.txt")
     with open(summary_file, "w", encoding="utf-8") as f:
         f.write(f"Expression analysis report - generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
         f.write("=" * 80 + "\n\n")
 
         # Handle the personality expressions first
         personality_exprs = []
         personality_file = os.path.join("data", "expression", "personality", "expressions.json")
         if os.path.exists(personality_file):
             with open(personality_file, "r", encoding="utf-8") as pf:
                 personality_exprs = json.load(pf)
 
         # Record the total number of personality expressions
         total_personality = len(personality_exprs)
 
         # Sort and keep the top 20
         personality_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
         personality_exprs = personality_exprs[:20]
 
         # Write the personality expression report
         personality_report = os.path.join(personality_dir, f"expressions_{timestamp}.txt")
         with open(personality_report, "w", encoding="utf-8") as pf:
             pf.write("[Personality expressions]\n")
             pf.write("=" * 40 + "\n")
             write_expressions(pf, personality_exprs, "Personality expressions")
 
         # Write the personality section of the summary report
         f.write("[Personality expressions]\n")
         f.write("=" * 40 + "\n")
         f.write(f"Total personality expressions: {total_personality} (showing top 20)\n")
         f.write(f"Detailed report: {personality_report}\n")
         f.write("-" * 40 + "\n\n")
 
         # Handle each group's expressions
         f.write("[Group expressions]\n")
         f.write("=" * 40 + "\n\n")
 
         for chat_id in chat_ids:
             style_exprs, grammar_exprs, _ = load_expressions(chat_id)
 
             # Record the totals
             total_style = len(style_exprs)
             total_grammar = len(grammar_exprs)
 
             # Sort each list
             style_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
             grammar_exprs.sort(key=lambda x: x.get("count", 0), reverse=True)
 
             # Keep only the top 20
             style_exprs = style_exprs[:20]
             grammar_exprs = grammar_exprs[:20]
 
             # Get the group name
             group_name = get_group_name(chat_id)
 
             # Create the group subdirectory (using the cleaned name)
             safe_group_name = clean_group_name(group_name)
             group_dir = os.path.join(output_dir, f"{safe_group_name}_{chat_id}")
             os.makedirs(group_dir, exist_ok=True)
 
             # Write the detailed group report
             group_file = os.path.join(group_dir, f"expressions_{timestamp}.txt")
             write_group_report(group_file, group_name, chat_id, style_exprs, grammar_exprs)
 
             # Write the group's section of the summary report
             f.write(f"Group: {group_name} (ID: {chat_id})\n")
             f.write("-" * 40 + "\n")
@@ -191,11 +204,12 @@ def analyze_expressions():
             f.write(f"Total grammar traits: {total_grammar} (showing top 20)\n")
             f.write(f"Detailed report: {group_file}\n")
             f.write("-" * 40 + "\n\n")
 
     print("Analysis reports generated:")
     print(f"Summary report: {summary_file}")
     print(f"Personality expression report: {personality_report}")
     print(f"Per-group detailed reports are under: {output_dir}")
 
+
 if __name__ == "__main__":
     analyze_expressions()
@@ -71,14 +71,14 @@ def analyze_group_similarity():
     # Collect all group directories
     base_dir = Path("data/expression/learnt_style")
     group_dirs = [d for d in base_dir.iterdir() if d.is_dir()]
 
     # Load every group's data and filter
     valid_groups = []
     valid_names = []
     valid_situations = []
     valid_styles = []
     valid_combined = []
 
     for d in group_dirs:
         situations, styles, combined, total_count = load_group_data(d)
-        if total_count >= 50: # keep only groups with at least 50 entries
+        if total_count >= 50:  # keep only groups with at least 50 entries
@@ -87,11 +87,11 @@ def analyze_group_similarity():
             valid_situations.append(" ".join(situations))
             valid_styles.append(" ".join(styles))
             valid_combined.append(" ".join(combined))
 
     if not valid_groups:
         print("No group with at least 50 entries was found")
         return
 
     # Create the TF-IDF vectorizer
     vectorizer = TfidfVectorizer()
 
@@ -3,117 +3,123 @@ import json
 import random
 from typing import List, Dict, Tuple
 import glob
-from datetime import datetime
 
 MAX_EXPRESSION_COUNT = 300  # maximum number of expressions kept per group
 MIN_COUNT_THRESHOLD = 0.01  # minimum usage-count threshold
 
+
 def load_expressions(chat_id: str) -> Tuple[List[Dict], List[Dict]]:
     """Load the expressions learned for a given group chat."""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
     grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
 
     style_exprs = []
     grammar_exprs = []
 
     if os.path.exists(style_file):
         with open(style_file, "r", encoding="utf-8") as f:
             style_exprs = json.load(f)
 
     if os.path.exists(grammar_file):
         with open(grammar_file, "r", encoding="utf-8") as f:
             grammar_exprs = json.load(f)
 
     return style_exprs, grammar_exprs
 
+
 def save_expressions(chat_id: str, style_exprs: List[Dict], grammar_exprs: List[Dict]) -> None:
     """Save the expressions back to their files."""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
     grammar_file = os.path.join("data", "expression", "learnt_grammar", str(chat_id), "expressions.json")
 
     os.makedirs(os.path.dirname(style_file), exist_ok=True)
     os.makedirs(os.path.dirname(grammar_file), exist_ok=True)
 
     with open(style_file, "w", encoding="utf-8") as f:
         json.dump(style_exprs, f, ensure_ascii=False, indent=2)
 
     with open(grammar_file, "w", encoding="utf-8") as f:
         json.dump(grammar_exprs, f, ensure_ascii=False, indent=2)
 
+
 def cleanup_expressions(expressions: List[Dict]) -> List[Dict]:
     """Clean up a list of expressions."""
     if not expressions:
         return []
 
     # 1. Drop expressions whose usage count is too low
     expressions = [expr for expr in expressions if expr.get("count", 0) > MIN_COUNT_THRESHOLD]
 
     # 2. If the list is over the limit, randomly drop the excess
     if len(expressions) > MAX_EXPRESSION_COUNT:
         # Sort by usage count
         expressions.sort(key=lambda x: x.get("count", 0), reverse=True)
 
         # Keep the top 50% high-frequency expressions
         keep_count = MAX_EXPRESSION_COUNT // 2
         keep_exprs = expressions[:keep_count]
 
         # Randomly sample from the remaining expressions
         remaining_exprs = expressions[keep_count:]
         random.shuffle(remaining_exprs)
-        keep_exprs.extend(remaining_exprs[:MAX_EXPRESSION_COUNT - keep_count])
+        keep_exprs.extend(remaining_exprs[: MAX_EXPRESSION_COUNT - keep_count])
 
         expressions = keep_exprs
 
     return expressions
 
+
 def main():
     # Collect all group chat IDs
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
     chat_ids = [os.path.basename(d) for d in style_dirs]
 
     if not chat_ids:
         print("No expression data found for any group chat")
         return
 
     print(f"Starting cleanup of expression data for {len(chat_ids)} group chats...")
 
     total_style_before = 0
     total_style_after = 0
     total_grammar_before = 0
     total_grammar_after = 0
 
     for chat_id in chat_ids:
         print(f"\nProcessing group chat {chat_id}:")
 
         # Load the expressions
         style_exprs, grammar_exprs = load_expressions(chat_id)
 
         # Record the counts before cleanup
         style_count_before = len(style_exprs)
         grammar_count_before = len(grammar_exprs)
         total_style_before += style_count_before
         total_grammar_before += grammar_count_before
 
         # Clean up the expressions
         style_exprs = cleanup_expressions(style_exprs)
         grammar_exprs = cleanup_expressions(grammar_exprs)
 
         # Record the counts after cleanup
         style_count_after = len(style_exprs)
         grammar_count_after = len(grammar_exprs)
         total_style_after += style_count_after
         total_grammar_after += grammar_count_after
 
         # Save the cleaned expressions
         save_expressions(chat_id, style_exprs, grammar_exprs)
 
         print(f"Language styles: {style_count_before} -> {style_count_after}")
         print(f"Grammar traits: {grammar_count_before} -> {grammar_count_after}")
 
     print("\nCleanup finished!")
     print(f"Total language styles: {total_style_before} -> {total_style_after}")
     print(f"Total grammar traits: {total_grammar_before} -> {total_grammar_after}")
-    print(f"Removed {total_style_before + total_grammar_before - total_style_after - total_grammar_after} expressions in total")
+    print(
+        f"Removed {total_style_before + total_grammar_before - total_style_after - total_grammar_after} expressions in total"
+    )
 
 
 if __name__ == "__main__":
     main()
@@ -1,5 +1,6 @@
 import os
 import sys
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
 import json
@@ -15,13 +16,15 @@ import random
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import global_config
 
+
 def clean_group_name(name: str) -> str:
     """Clean a group name, keeping only Chinese and English characters."""
-    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', '', name)
+    cleaned = re.sub(r"[^\u4e00-\u9fa5a-zA-Z]", "", name)
     if not cleaned:
         cleaned = datetime.now().strftime("%Y%m%d")
     return cleaned
 
+
 def get_group_name(stream_id: str) -> str:
     """Look up the group name from the database."""
     conn = sqlite3.connect("data/maibot.db")
@@ -49,76 +52,79 @@ def get_group_name(stream_id: str) -> str:
         return clean_group_name(f"{platform}{stream_id[:8]}")
     return stream_id
 
+
 def load_expressions(chat_id: str) -> List[Dict]:
     """Load the expressions learned for a given group chat."""
     style_file = os.path.join("data", "expression", "learnt_style", str(chat_id), "expressions.json")
 
     style_exprs = []
 
     if os.path.exists(style_file):
         with open(style_file, "r", encoding="utf-8") as f:
             style_exprs = json.load(f)
 
     # If there are more than 10 expressions, randomly pick 10
     if len(style_exprs) > 50:
         style_exprs = random.sample(style_exprs, 50)
         print(f"\nRandomly selected 10 of {len(style_exprs)} expressions for matching")
 
     return style_exprs
 
+
-def find_similar_expressions_tfidf(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10) -> List[Tuple[str, str, float]]:
+def find_similar_expressions_tfidf(
+    input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 10
+) -> List[Tuple[str, str, float]]:
     """Use TF-IDF to find the top_k expressions most similar to the input text."""
     if not expressions:
         return []
 
     # Prepare the text data
     if mode == "style":
-        texts = [expr['style'] for expr in expressions]
+        texts = [expr["style"] for expr in expressions]
     elif mode == "situation":
-        texts = [expr['situation'] for expr in expressions]
+        texts = [expr["situation"] for expr in expressions]
     else:  # both
         texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
 
     texts.append(input_text)  # append the input text
 
     # Vectorize with TF-IDF
     vectorizer = TfidfVectorizer()
     tfidf_matrix = vectorizer.fit_transform(texts)
 
     # Compute cosine similarity
     similarity_matrix = cosine_similarity(tfidf_matrix)
 
     # Take the input text's similarity scores (the last row)
     scores = similarity_matrix[-1][:-1]  # exclude the similarity with itself
 
     # Get the top_k indices
     top_indices = np.argsort(scores)[::-1][:top_k]
 
     # Collect the similar expressions
     similar_exprs = []
     for idx in top_indices:
         if scores[idx] > 0:  # keep only entries with nonzero similarity
-            similar_exprs.append((
-                expressions[idx]['style'],
-                expressions[idx]['situation'],
-                scores[idx]
-            ))
+            similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], scores[idx]))
 
     return similar_exprs
 
+
-async def find_similar_expressions_embedding(input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5) -> List[Tuple[str, str, float]]:
+async def find_similar_expressions_embedding(
+    input_text: str, expressions: List[Dict], mode: str = "both", top_k: int = 5
+) -> List[Tuple[str, str, float]]:
     """Use an embedding model to find the top_k expressions most similar to the input text."""
     if not expressions:
         return []
 
     # Prepare the text data
     if mode == "style":
-        texts = [expr['style'] for expr in expressions]
+        texts = [expr["style"] for expr in expressions]
     elif mode == "situation":
-        texts = [expr['situation'] for expr in expressions]
+        texts = [expr["situation"] for expr in expressions]
     else:  # both
         texts = [f"{expr['situation']} {expr['style']}" for expr in expressions]
 
     # Get the embedding vectors
     llm_request = LLMRequest(global_config.model.embedding)
     text_embeddings = []
@@ -126,73 +132,70 @@ async def find_similar_expressions_embedding(input_text: str, expressions: List[
         embedding = await llm_request.get_embedding(text)
         if embedding:
             text_embeddings.append(embedding)
 
     input_embedding = await llm_request.get_embedding(input_text)
     if not input_embedding or not text_embeddings:
         return []
 
     # Compute cosine similarity
     text_embeddings = np.array(text_embeddings)
     similarities = np.dot(text_embeddings, input_embedding) / (
         np.linalg.norm(text_embeddings, axis=1) * np.linalg.norm(input_embedding)
     )
 
     # Get the top_k indices
     top_indices = np.argsort(similarities)[::-1][:top_k]
 
     # Collect the similar expressions
     similar_exprs = []
     for idx in top_indices:
         if similarities[idx] > 0:  # keep only entries with nonzero similarity
-            similar_exprs.append((
-                expressions[idx]['style'],
-                expressions[idx]['situation'],
-                similarities[idx]
-            ))
+            similar_exprs.append((expressions[idx]["style"], expressions[idx]["situation"], similarities[idx]))
 
     return similar_exprs
 
+
 async def main():
     # Collect all group chat IDs
     style_dirs = glob.glob(os.path.join("data", "expression", "learnt_style", "*"))
     chat_ids = [os.path.basename(d) for d in style_dirs]
 
     if not chat_ids:
         print("No expression data found for any group chat")
         return
 
     print("Available group chats:")
     for i, chat_id in enumerate(chat_ids, 1):
         group_name = get_group_name(chat_id)
         print(f"{i}. {group_name}")
 
     while True:
         try:
             choice = int(input("\nSelect the number of the group chat to analyze (0 to exit): "))
             if choice == 0:
                 break
             if 1 <= choice <= len(chat_ids):
-                chat_id = chat_ids[choice-1]
+                chat_id = chat_ids[choice - 1]
                 break
             print("Invalid choice, please try again")
         except ValueError:
             print("Please enter a valid number")
 
     if choice == 0:
         return
 
     # Load the expressions
     style_exprs = load_expressions(chat_id)
 
     group_name = get_group_name(chat_id)
     print(f"\nSelected group chat: {group_name}")
 
     # Choose the matching mode
     print("\nChoose a matching mode:")
     print("1. Match by style")
     print("2. Match by situation")
     print("3. Consider both")
 
     while True:
         try:
             mode_choice = int(input("\nChoose a matching mode (1-3): "))
@@ -201,19 +204,15 @@ async def main():
             print("Invalid choice, please try again")
         except ValueError:
             print("Please enter a valid number")
 
-    mode_map = {
-        1: "style",
-        2: "situation",
-        3: "both"
-    }
+    mode_map = {1: "style", 2: "situation", 3: "both"}
     mode = mode_map[mode_choice]
 
     # Choose the matching method
     print("\nChoose a matching method:")
     print("1. TF-IDF")
     print("2. Embedding model")
 
     while True:
         try:
             method_choice = int(input("\nChoose a matching method (1-2): "))
@@ -222,20 +221,20 @@ async def main():
             print("Invalid choice, please try again")
         except ValueError:
             print("Please enter a valid number")
 
     while True:
         input_text = input("\nEnter text to match (q to quit): ")
-        if input_text.lower() == 'q':
+        if input_text.lower() == "q":
             break
 
         if not input_text.strip():
             continue
 
         if method_choice == 1:
             similar_exprs = find_similar_expressions_tfidf(input_text, style_exprs, mode)
         else:
             similar_exprs = await find_similar_expressions_embedding(input_text, style_exprs, mode)
 
         if similar_exprs:
             print("\nFound these similar expressions:")
             for style, situation, score in similar_exprs:
@@ -246,6 +245,8 @@ async def main():
         else:
             print("\nNo similar expressions found")
 
+
 if __name__ == "__main__":
     import asyncio
+
     asyncio.run(main())
File diff suppressed because it is too large
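The remaining hunks replace calls made through the person_info_manager singleton with calls on the PersonInfoManager class itself. A minimal sketch of the assumed pattern (the method body and the @staticmethod detail are illustrative assumptions; only the names and call sites come from the diff):

    class PersonInfoManager:
        @staticmethod
        def get_person_id(platform: str, user_id: str) -> str:
            # assumed: derives a stable id from its arguments alone, which is
            # why call sites can use the class rather than the instance
            return f"{platform}_{user_id}"

    person_info_manager = PersonInfoManager()  # singleton kept for stateful lookups such as get_value

Instance-bound calls such as person_info_manager.get_value(...) are left untouched by the commit.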
@@ -1,7 +1,7 @@
 from typing import Optional, Tuple, Dict
 from src.common.logger_manager import get_logger
 from src.chat.message_receive.chat_stream import chat_manager
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 
 logger = get_logger("heartflow_utils")
 
@@ -47,7 +47,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
         # Try to fetch person info
         try:
             # Assume get_person_id is sync (as per original code), keep using to_thread
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
             person_name = None
             if person_id:
                 # get_value is async, so await it directly
@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageThinking
 from src.chat.normal_chat.normal_prompt import prompt_builder
 from src.chat.utils.timer_calculator import Timer
 from src.common.logger_manager import get_logger
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 from src.chat.utils.utils import process_llm_response
 
 
@@ -66,7 +66,7 @@ class NormalChatGenerator:
         enable_planner: bool = False,
         available_actions=None,
     ):
-        person_id = person_info_manager.get_person_id(
+        person_id = PersonInfoManager.get_person_id(
             message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
         )
 
@@ -96,7 +96,7 @@ class BaseWillingManager(ABC):
         self.logger: LoguruLogger = logger
 
     def setup(self, message: MessageRecv, chat: ChatStream, is_mentioned_bot: bool, interested_rate: float):
-        person_id = person_info_manager.get_person_id(chat.platform, chat.user_info.user_id)
+        person_id = PersonInfoManager.get_person_id(chat.platform, chat.user_info.user_id)
         self.ongoing_messages[message.message_info.message_id] = WillingInfo(
             message=message,
             chat=chat,
@@ -4,7 +4,7 @@ import time  # import the time module to get the current time
 import random
 import re
 from src.common.message_repository import find_messages, count_messages
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 from src.chat.utils.utils import translate_timestamp_to_human_readable
 from rich.traceback import install
 from src.common.database.database_model import ActionRecords
@@ -219,7 +219,7 @@ def _build_readable_messages_internal(
         if not all([platform, user_id, timestamp is not None]):
             continue
 
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
         # Decide whether to replace the bot's name based on replace_bot_name
         if replace_bot_name and user_id == global_config.bot.qq_account:
             person_name = f"{global_config.bot.nickname}(you)"
@@ -241,7 +241,7 @@ def _build_readable_messages_internal(
         if match:
             aaa = match.group(1)
             bbb = match.group(2)
-            reply_person_id = person_info_manager.get_person_id(platform, bbb)
+            reply_person_id = PersonInfoManager.get_person_id(platform, bbb)
             reply_person_name = person_info_manager.get_value_sync(reply_person_id, "person_name")
             if not reply_person_name:
                 reply_person_name = aaa
@@ -258,7 +258,7 @@ def _build_readable_messages_internal(
             new_content += content[last_end : m.start()]
             aaa = m.group(1)
             bbb = m.group(2)
-            at_person_id = person_info_manager.get_person_id(platform, bbb)
+            at_person_id = PersonInfoManager.get_person_id(platform, bbb)
             at_person_name = person_info_manager.get_value_sync(at_person_id, "person_name")
             if not at_person_name:
                 at_person_name = aaa
@@ -572,7 +572,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
             # print("SELF11111111111111")
             return "SELF"
         try:
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
         except Exception as _e:
             person_id = None
         if not person_id:
@@ -673,7 +673,7 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
         if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
             continue
 
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
 
         # Only add it when a valid person_id was obtained
         if person_id:
@@ -1,9 +1,9 @@
 from src.manager.async_task_manager import AsyncTask
 from src.common.logger_manager import get_logger
+from src.person_info.person_info import PersonInfoManager
 from src.person_info.relationship_manager import relationship_manager
 from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp
 from src.config.config import global_config
-from src.person_info.person_info import person_info_manager
 from src.chat.message_receive.chat_stream import chat_manager
 import time
 import random
@@ -95,7 +95,7 @@ class ImpressionUpdateTask(AsyncTask):
             if msg["user_nickname"] == global_config.bot.nickname:
                 continue
 
-            person_id = person_info_manager.get_person_id(msg["chat_info_platform"], msg["user_id"])
+            person_id = PersonInfoManager.get_person_id(msg["chat_info_platform"], msg["user_id"])
             if not person_id:
                 logger.warning(f"No person_id found for user {msg['user_nickname']}")
                 continue
@@ -1,6 +1,6 @@
 from src.common.logger_manager import get_logger
 import math
-from src.person_info.person_info import person_info_manager
+from src.person_info.person_info import person_info_manager, PersonInfoManager
 import time
 import random
 from src.llm_models.utils_model import LLMRequest
@@ -91,7 +91,7 @@ class RelationshipManager:
     @staticmethod
     async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
         """Determine whether someone is already known."""
-        person_id = person_info_manager.get_person_id(platform, user_id)
+        person_id = PersonInfoManager.get_person_id(platform, user_id)
         # Generate a unique person_name
         unique_nickname = await person_info_manager._generate_unique_person_name(user_nickname)
         data = {
@@ -116,7 +116,7 @@ class RelationshipManager:
         if is_id:
             person_id = person
         else:
-            person_id = person_info_manager.get_person_id(person[0], person[1])
+            person_id = PersonInfoManager.get_person_id(person[0], person[1])
 
         person_name = await person_info_manager.get_value(person_id, "person_name")
         if not person_name or person_name == "none":
@@ -198,7 +198,7 @@ class RelationshipManager:
             )
             replace_user_id = msg.get("user_id")
             replace_platform = msg.get("chat_info_platform")
-            replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
+            replace_person_id = PersonInfoManager.get_person_id(replace_platform, replace_user_id)
             replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")
 
             # Skip the bot itself
@@ -467,7 +467,7 @@ class RelationshipManager:
         for i, msg in enumerate(messages):
             user_id = msg.get("user_id")
             platform = msg.get("chat_info_platform")
-            person_id = person_info_manager.get_person_id(platform, user_id)
+            person_id = PersonInfoManager.get_person_id(platform, user_id)
             if person_id == target_person_id:
                 target_indices.append(i)
 