refactor: 全部代码格式化

This commit is contained in:
Rikki
2025-03-30 04:56:46 +08:00
parent 7adaa2f5a8
commit b2fc824afd
21 changed files with 491 additions and 514 deletions

View File

@@ -35,6 +35,7 @@ else:
print(f"未找到环境变量文件: {env_path}")
print("将使用默认配置")
class ChatBasedPersonalityEvaluator:
def __init__(self):
self.personality_traits = {"开放性": 0, "严谨性": 0, "外向性": 0, "宜人性": 0, "神经质": 0}
@@ -50,16 +51,14 @@ class ChatBasedPersonalityEvaluator:
continue
scene_keys = list(scenes.keys())
selected_scenes = random.sample(scene_keys, min(3, len(scene_keys)))
for scene_key in selected_scenes:
scene = scenes[scene_key]
other_traits = [t for t in PERSONALITY_SCENES if t != trait]
secondary_trait = random.choice(other_traits)
self.scenarios.append({
"场景": scene["scenario"],
"评估维度": [trait, secondary_trait],
"场景编号": scene_key
})
self.scenarios.append(
{"场景": scene["scenario"], "评估维度": [trait, secondary_trait], "场景编号": scene_key}
)
def analyze_chat_context(self, messages: List[Dict]) -> str:
"""
@@ -67,20 +66,21 @@ class ChatBasedPersonalityEvaluator:
"""
context = ""
for msg in messages:
nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
content = msg.get('processed_plain_text', msg.get('detailed_plain_text', ''))
nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
content = msg.get("processed_plain_text", msg.get("detailed_plain_text", ""))
if content:
context += f"{nickname}: {content}\n"
return context
def evaluate_chat_response(
self, user_nickname: str, chat_context: str, dimensions: List[str] = None) -> Dict[str, float]:
self, user_nickname: str, chat_context: str, dimensions: List[str] = None
) -> Dict[str, float]:
"""
评估聊天内容在各个人格维度上的得分
"""
# 使用所有维度进行评估
dimensions = list(self.personality_traits.keys())
dimension_descriptions = []
for dim in dimensions:
desc = FACTOR_DESCRIPTIONS.get(dim, "")
@@ -136,18 +136,19 @@ class ChatBasedPersonalityEvaluator:
def evaluate_user_personality(self, qq_id: str, num_samples: int = 10, context_length: int = 5) -> Dict:
"""
基于用户的聊天记录评估人格特征
Args:
qq_id (str): 用户QQ号
num_samples (int): 要分析的聊天片段数量
context_length (int): 每个聊天片段的上下文长度
Returns:
Dict: 评估结果
"""
# 获取用户的随机消息及其上下文
chat_contexts, user_nickname = self.message_analyzer.get_user_random_contexts(
qq_id, num_messages=num_samples, context_length=context_length)
qq_id, num_messages=num_samples, context_length=context_length
)
if not chat_contexts:
return {"error": f"没有找到QQ号 {qq_id} 的消息记录"}
@@ -155,7 +156,7 @@ class ChatBasedPersonalityEvaluator:
final_scores = defaultdict(float)
dimension_counts = defaultdict(int)
chat_samples = []
# 清空历史记录
self.trait_scores_history.clear()
@@ -163,13 +164,11 @@ class ChatBasedPersonalityEvaluator:
for chat_context in chat_contexts:
# 评估这段聊天内容的所有维度
scores = self.evaluate_chat_response(user_nickname, chat_context)
# 记录样本
chat_samples.append({
"聊天内容": chat_context,
"评估维度": list(self.personality_traits.keys()),
"评分": scores
})
chat_samples.append(
{"聊天内容": chat_context, "评估维度": list(self.personality_traits.keys()), "评分": scores}
)
# 更新总分和历史记录
for dimension, score in scores.items():
@@ -196,7 +195,7 @@ class ChatBasedPersonalityEvaluator:
"人格特征评分": average_scores,
"维度评估次数": dict(dimension_counts),
"详细样本": chat_samples,
"特质得分历史": {k: v for k, v in self.trait_scores_history.items()}
"特质得分历史": {k: v for k, v in self.trait_scores_history.items()},
}
# 保存结果
@@ -215,40 +214,41 @@ class ChatBasedPersonalityEvaluator:
chinese_fonts = []
for f in fm.fontManager.ttflist:
try:
if '' in f.name or 'SC' in f.name or '' in f.name or '' in f.name or '微软' in f.name:
if "" in f.name or "SC" in f.name or "" in f.name or "" in f.name or "微软" in f.name:
chinese_fonts.append(f.name)
except Exception:
continue
if chinese_fonts:
plt.rcParams['font.sans-serif'] = chinese_fonts + ['SimHei', 'Microsoft YaHei', 'Arial Unicode MS']
plt.rcParams["font.sans-serif"] = chinese_fonts + ["SimHei", "Microsoft YaHei", "Arial Unicode MS"]
else:
# 如果没有找到中文字体,使用默认字体,并将中文昵称转换为拼音或英文
try:
from pypinyin import lazy_pinyin
user_nickname = ''.join(lazy_pinyin(user_nickname))
user_nickname = "".join(lazy_pinyin(user_nickname))
except ImportError:
user_nickname = "User" # 如果无法转换为拼音,使用默认英文
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
plt.figure(figsize=(12, 6))
plt.style.use('bmh') # 使用内置的bmh样式它有类似seaborn的美观效果
plt.style.use("bmh") # 使用内置的bmh样式它有类似seaborn的美观效果
colors = {
"开放性": "#FF9999",
"严谨性": "#66B2FF",
"外向性": "#99FF99",
"宜人性": "#FFCC99",
"神经质": "#FF99CC"
"神经质": "#FF99CC",
}
# 计算每个维度在每个时间点的累计平均分
cumulative_averages = {}
for trait, scores in self.trait_scores_history.items():
if not scores:
continue
averages = []
total = 0
valid_count = 0
@@ -264,25 +264,25 @@ class ChatBasedPersonalityEvaluator:
averages.append(averages[-1])
else:
continue # 跳过无效分数
if averages: # 只有在有有效分数的情况下才添加到累计平均中
cumulative_averages[trait] = averages
# 绘制每个维度的累计平均分变化趋势
for trait, averages in cumulative_averages.items():
x = range(1, len(averages) + 1)
plt.plot(x, averages, 'o-', label=trait, color=colors.get(trait), linewidth=2, markersize=8)
plt.plot(x, averages, "o-", label=trait, color=colors.get(trait), linewidth=2, markersize=8)
# 添加趋势线
z = np.polyfit(x, averages, 1)
p = np.poly1d(z)
plt.plot(x, p(x), '--', color=colors.get(trait), alpha=0.5)
plt.plot(x, p(x), "--", color=colors.get(trait), alpha=0.5)
plt.title(f"{user_nickname} 的人格特质累计平均分变化趋势", fontsize=14, pad=20)
plt.xlabel("评估次数", fontsize=12)
plt.ylabel("累计平均分", fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.grid(True, linestyle="--", alpha=0.7)
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.ylim(0, 7)
plt.tight_layout()
@@ -290,38 +290,39 @@ class ChatBasedPersonalityEvaluator:
os.makedirs("results/plots", exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
plot_file = f"results/plots/personality_trend_{qq_id}_{timestamp}.png"
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
plt.savefig(plot_file, dpi=300, bbox_inches="tight")
plt.close()
def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length: int = 5) -> str:
"""
分析用户人格特征的便捷函数
Args:
qq_id (str): 用户QQ号
num_samples (int): 要分析的聊天片段数量
context_length (int): 每个聊天片段的上下文长度
Returns:
str: 格式化的分析结果
"""
evaluator = ChatBasedPersonalityEvaluator()
result = evaluator.evaluate_user_personality(qq_id, num_samples, context_length)
if "error" in result:
return result["error"]
# 格式化输出
output = f"QQ号 {qq_id} ({result['用户昵称']}) 的人格特征分析结果:\n"
output += "=" * 50 + "\n\n"
output += "人格特征评分:\n"
for trait, score in result["人格特征评分"].items():
if score == 0:
output += f"{trait}: 数据不足,无法判断 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
else:
output += f"{trait}: {score}/6 (评估次数: {result['维度评估次数'].get(trait, 0)})\n"
# 添加变化趋势描述
if trait in result["特质得分历史"] and len(result["特质得分历史"][trait]) > 1:
scores = [s for s in result["特质得分历史"][trait] if s != 0] # 过滤掉无效分数
@@ -334,13 +335,14 @@ def analyze_user_personality(qq_id: str, num_samples: int = 10, context_length:
else:
trend_desc = "呈下降趋势"
output += f" 变化趋势: {trend_desc} (斜率: {trend:.2f})\n"
output += f"\n分析样本数量:{result['样本数量']}\n"
output += f"结果已保存至results/personality_result_{qq_id}.json\n"
output += "变化趋势图已保存至results/plots/目录\n"
return output
if __name__ == "__main__":
# 测试代码
# test_qq = "" # 替换为要测试的QQ号

View File

@@ -82,7 +82,6 @@ class PersonalityEvaluator_direct:
dimensions_text = "\n".join(dimension_descriptions)
prompt = f"""请根据以下场景和用户描述评估用户在大五人格模型中的相关维度得分1-6分
场景描述:

View File

@@ -14,18 +14,19 @@ sys.path.append(root_path)
from src.common.database import db # noqa: E402
class MessageAnalyzer:
def __init__(self):
self.messages_collection = db["messages"]
def get_message_context(self, message_id: int, context_length: int = 5) -> Optional[List[Dict]]:
"""
获取指定消息ID的上下文消息列表
Args:
message_id (int): 消息ID
context_length (int): 上下文长度(单侧,总长度为 2*context_length + 1
Returns:
Optional[List[Dict]]: 消息列表如果未找到则返回None
"""
@@ -33,110 +34,110 @@ class MessageAnalyzer:
target_message = self.messages_collection.find_one({"message_id": message_id})
if not target_message:
return None
# 获取该消息的stream_id
stream_id = target_message.get('chat_info', {}).get('stream_id')
stream_id = target_message.get("chat_info", {}).get("stream_id")
if not stream_id:
return None
# 获取同一stream_id的所有消息
stream_messages = list(self.messages_collection.find({
"chat_info.stream_id": stream_id
}).sort("time", 1))
stream_messages = list(self.messages_collection.find({"chat_info.stream_id": stream_id}).sort("time", 1))
# 找到目标消息在列表中的位置
target_index = None
for i, msg in enumerate(stream_messages):
if msg['message_id'] == message_id:
if msg["message_id"] == message_id:
target_index = i
break
if target_index is None:
return None
# 获取目标消息前后的消息
start_index = max(0, target_index - context_length)
end_index = min(len(stream_messages), target_index + context_length + 1)
return stream_messages[start_index:end_index]
def format_messages(self, messages: List[Dict], target_message_id: Optional[int] = None) -> str:
"""
格式化消息列表为可读字符串
Args:
messages (List[Dict]): 消息列表
target_message_id (Optional[int]): 目标消息ID用于标记
Returns:
str: 格式化的消息字符串
"""
if not messages:
return "没有消息记录"
reply = ""
for msg in messages:
# 消息时间
msg_time = datetime.datetime.fromtimestamp(int(msg['time'])).strftime("%Y-%m-%d %H:%M:%S")
msg_time = datetime.datetime.fromtimestamp(int(msg["time"])).strftime("%Y-%m-%d %H:%M:%S")
# 获取消息内容
message_text = msg.get('processed_plain_text', msg.get('detailed_plain_text', '无消息内容'))
nickname = msg.get('user_info', {}).get('user_nickname', '未知用户')
message_text = msg.get("processed_plain_text", msg.get("detailed_plain_text", "无消息内容"))
nickname = msg.get("user_info", {}).get("user_nickname", "未知用户")
# 标记当前消息
is_target = "" if target_message_id and msg['message_id'] == target_message_id else " "
is_target = "" if target_message_id and msg["message_id"] == target_message_id else " "
reply += f"{is_target}[{msg_time}] {nickname}: {message_text}\n"
if target_message_id and msg['message_id'] == target_message_id:
if target_message_id and msg["message_id"] == target_message_id:
reply += " " + "-" * 50 + "\n"
return reply
def get_user_random_contexts(
self, qq_id: str, num_messages: int = 10, context_length: int = 5) -> tuple[List[str], str]: # noqa: E501
self, qq_id: str, num_messages: int = 10, context_length: int = 5
) -> tuple[List[str], str]: # noqa: E501
"""
获取用户的随机消息及其上下文
Args:
qq_id (str): QQ号
num_messages (int): 要获取的随机消息数量
context_length (int): 每条消息的上下文长度(单侧)
Returns:
tuple[List[str], str]: (每个消息上下文的格式化字符串列表, 用户昵称)
"""
if not qq_id:
return [], ""
# 获取用户所有消息
all_messages = list(self.messages_collection.find({"user_info.user_id": int(qq_id)}))
if not all_messages:
return [], ""
# 获取用户昵称
user_nickname = all_messages[0].get('chat_info', {}).get('user_info', {}).get('user_nickname', '未知用户')
user_nickname = all_messages[0].get("chat_info", {}).get("user_info", {}).get("user_nickname", "未知用户")
# 随机选择指定数量的消息
selected_messages = random.sample(all_messages, min(num_messages, len(all_messages)))
# 按时间排序
selected_messages.sort(key=lambda x: int(x['time']))
selected_messages.sort(key=lambda x: int(x["time"]))
# 存储所有上下文消息
context_list = []
# 获取每条消息的上下文
for msg in selected_messages:
message_id = msg['message_id']
message_id = msg["message_id"]
# 获取消息上下文
context_messages = self.get_message_context(message_id, context_length)
if context_messages:
formatted_context = self.format_messages(context_messages, message_id)
context_list.append(formatted_context)
return context_list, user_nickname
if __name__ == "__main__":
# 测试代码
analyzer = MessageAnalyzer()
@@ -145,7 +146,7 @@ if __name__ == "__main__":
print("-" * 50)
# 获取5条消息每条消息前后各3条上下文
contexts, nickname = analyzer.get_user_random_contexts(test_qq, num_messages=5, context_length=3)
print(f"用户昵称: {nickname}\n")
# 打印每个上下文
for i, context in enumerate(contexts, 1):