ruff
@@ -7,7 +7,6 @@ from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.manager.mood_manager import mood_manager
from src.individuality.individuality import individuality
import json
from json_repair import repair_json
from datetime import datetime
@@ -90,9 +89,7 @@ class RelationshipManager:
        return is_known

    @staticmethod
-    async def first_knowing_some_one(
-        platform: str, user_id: str, user_nickname: str, user_cardname: str
-    ):
+    async def first_knowing_some_one(platform: str, user_id: str, user_nickname: str, user_cardname: str):
        """判断是否认识某人"""
        person_id = person_info_manager.get_person_id(platform, user_id)
        # 生成唯一的 person_name
@@ -112,7 +109,7 @@ class RelationshipManager:
        )
        # 尝试生成更好的名字
        # await person_info_manager.qv_person_name(
-        # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
+        # person_id=person_id, user_nickname=user_nickname, user_cardname=user_cardname, user_avatar=user_avatar
        # )

    async def build_relationship_info(self, person, is_id: bool = False) -> str:
@@ -124,26 +121,24 @@ class RelationshipManager:
        person_name = await person_info_manager.get_value(person_id, "person_name")
        if not person_name or person_name == "none":
            return ""
-        impression = await person_info_manager.get_value(person_id, "impression")
+        # impression = await person_info_manager.get_value(person_id, "impression")
        points = await person_info_manager.get_value(person_id, "points") or []

        if isinstance(points, str):
            try:
                points = ast.literal_eval(points)
            except (SyntaxError, ValueError):
                points = []

        random_points = random.sample(points, min(5, len(points))) if points else []

        nickname_str = await person_info_manager.get_value(person_id, "nickname")
        platform = await person_info_manager.get_value(person_id, "platform")
        relation_prompt = f"'{person_name}' ,ta在{platform}上的昵称是{nickname_str}。"

        # if impression:
-        #     relation_prompt += f"你对ta的印象是:{impression}。"
+        #     relation_prompt += f"你对ta的印象是:{impression}。"

        if random_points:
            for point in random_points:
                # print(f"point: {point}")
@@ -151,13 +146,12 @@ class RelationshipManager:
                # print(f"point[0]: {point[0]}")
                point_str = f"时间:{point[2]}。内容:{point[0]}"
                relation_prompt += f"你记得{person_name}最近的点是:{point_str}。"

        return relation_prompt

    async def _update_list_field(self, person_id: str, field_name: str, new_items: list) -> None:
        """更新列表类型的字段,将新项目添加到现有列表中

        Args:
            person_id: 用户ID
            field_name: 字段名称
@@ -179,21 +173,21 @@ class RelationshipManager:
        """
        person_name = await person_info_manager.get_value(person_id, "person_name")
        nickname = await person_info_manager.get_value(person_id, "nickname")

        alias_str = ", ".join(global_config.bot.alias_names)
-        personality_block = individuality.get_personality_prompt(x_person=2, level=2)
-        identity_block = individuality.get_identity_prompt(x_person=2, level=2)
+        # personality_block = individuality.get_personality_prompt(x_person=2, level=2)
+        # identity_block = individuality.get_identity_prompt(x_person=2, level=2)

        user_messages = bot_engaged_messages

        current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")

        # 匿名化消息
        # 创建用户名称映射
        name_mapping = {}
        current_user = "A"
        user_count = 1

        # 遍历消息,构建映射
        for msg in user_messages:
            await person_info_manager.get_or_create_person(
@@ -206,37 +200,31 @@ class RelationshipManager:
            replace_platform = msg.get("chat_info_platform")
            replace_person_id = person_info_manager.get_person_id(replace_platform, replace_user_id)
            replace_person_name = await person_info_manager.get_value(replace_person_id, "person_name")

            # 跳过机器人自己
            if replace_user_id == global_config.bot.qq_account:
                name_mapping[f"{global_config.bot.nickname}"] = f"{global_config.bot.nickname}"
                continue

            # 跳过目标用户
            if replace_person_name == person_name:
                name_mapping[replace_person_name] = f"{person_name}"
                continue

            # 其他用户映射
            if replace_person_name not in name_mapping:
-                if current_user > 'Z':
-                    current_user = 'A'
+                if current_user > "Z":
+                    current_user = "A"
                    user_count += 1
                name_mapping[replace_person_name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
                current_user = chr(ord(current_user) + 1)

-        readable_messages = self.build_focus_readable_messages(
-            messages=user_messages,
-            target_person_id=person_id
-        )
+        readable_messages = self.build_focus_readable_messages(messages=user_messages, target_person_id=person_id)

        for original_name, mapped_name in name_mapping.items():
            # print(f"original_name: {original_name}, mapped_name: {mapped_name}")
            readable_messages = readable_messages.replace(f"{original_name}", f"{mapped_name}")

        prompt = f"""
你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。
请不要混淆你自己和{global_config.bot.nickname}和{person_name}。
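Aside (not part of the diff): a minimal sketch of how the anonymization mapping above behaves. The function name and its arguments are hypothetical; only the bot/target exemption and the A→Z counter scheme mirror the hunk.

def build_name_mapping(names, bot_name, target_name):
    """Map every other speaker to 用户A, 用户B, ... wrapping to 用户A2 after Z."""
    mapping = {}
    current_user = "A"
    user_count = 1
    for name in names:
        if name in (bot_name, target_name):
            mapping[name] = name  # the bot and the target user keep their real names
            continue
        if name not in mapping:
            if current_user > "Z":
                current_user = "A"
                user_count += 1
            mapping[name] = f"用户{current_user}{user_count if user_count > 1 else ''}"
            current_user = chr(ord(current_user) + 1)
    return mapping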
@@ -271,22 +259,22 @@ class RelationshipManager:
"weight": 0
}}
"""

        # 调用LLM生成印象
        points, _ = await self.relationship_llm.generate_response_async(prompt=prompt)
        points = points.strip()

        # 还原用户名称
        for original_name, mapped_name in name_mapping.items():
            points = points.replace(mapped_name, original_name)

        # logger.info(f"prompt: {prompt}")
        # logger.info(f"points: {points}")

        if not points:
            logger.warning(f"未能从LLM获取 {person_name} 的新印象")
            return

        # 解析JSON并转换为元组列表
        try:
            points = repair_json(points)
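Aside (not part of the diff): the repair_json call above exists to salvage malformed LLM output before parsing. A minimal sketch under assumptions — the "points"/"point"/"weight" field names and the sample timestamp are placeholders, since the actual parsing lines fall outside this hunk.

import json
from json_repair import repair_json

raw = '{"points": [{"point": "喜欢猫", "weight": 3}]'  # missing closing brace, a typical LLM slip
data = json.loads(repair_json(raw))                    # repair_json returns a fixed JSON string
points_list = [(p["point"], p["weight"], "2024-01-01 12:00:00") for p in data["points"]]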
@@ -307,7 +295,7 @@ class RelationshipManager:
        except (KeyError, TypeError) as e:
            logger.error(f"处理points数据失败: {e}, points: {points}")
            return

        current_points = await person_info_manager.get_value(person_id, "points") or []
        if isinstance(current_points, str):
            try:
@@ -318,7 +306,9 @@ class RelationshipManager:
        elif not isinstance(current_points, list):
            current_points = []
        current_points.extend(points_list)
-        await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
+        await person_info_manager.update_one_field(
+            person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
+        )

        # 将新记录添加到现有记录中
        if isinstance(current_points, list):
@@ -326,14 +316,14 @@ class RelationshipManager:
            for new_point in points_list:
                similar_points = []
                similar_indices = []

                # 在现有points中查找相似的点
                for i, existing_point in enumerate(current_points):
                    # 使用组合的相似度检查方法
                    if self.check_similarity(new_point[0], existing_point[0]):
                        similar_points.append(existing_point)
                        similar_indices.append(i)

                if similar_points:
                    # 合并相似的点
                    all_points = [new_point] + similar_points
@@ -343,14 +333,14 @@ class RelationshipManager:
                    total_weight = sum(p[1] for p in all_points)
                    # 使用最长的描述
                    longest_desc = max(all_points, key=lambda x: len(x[0]))[0]

                    # 创建合并后的点
                    merged_point = (longest_desc, total_weight, latest_time)

                    # 从现有points中移除已合并的点
                    for idx in sorted(similar_indices, reverse=True):
                        current_points.pop(idx)

                    # 添加合并后的点
                    current_points.append(merged_point)
                else:
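Aside (not part of the diff): a condensed sketch of the merge step in the two hunks above. check_similarity is passed in as a plain callable, and taking max() over the timestamp strings as latest_time is an assumption, since that line is elided from the hunk.

def merge_point(new_point, current_points, check_similarity):
    # new_point and every entry of current_points are (description, weight, "YYYY-MM-DD HH:MM:SS") tuples
    similar_indices = [i for i, p in enumerate(current_points) if check_similarity(new_point[0], p[0])]
    if not similar_indices:
        current_points.append(new_point)
        return
    all_points = [new_point] + [current_points[i] for i in similar_indices]
    latest_time = max(p[2] for p in all_points)                 # assumed: latest timestamp wins
    total_weight = sum(p[1] for p in all_points)                # weights are summed
    longest_desc = max(all_points, key=lambda p: len(p[0]))[0]  # keep the longest description
    for i in sorted(similar_indices, reverse=True):
        current_points.pop(i)
    current_points.append((longest_desc, total_weight, latest_time))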
@@ -359,7 +349,7 @@ class RelationshipManager:
        else:
            current_points = points_list

-        # 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
+        # 如果points超过10条,按权重随机选择多余的条目移动到forgotten_points
        if len(current_points) > 10:
            # 获取现有forgotten_points
            forgotten_points = await person_info_manager.get_value(person_id, "forgotten_points") or []
@@ -371,29 +361,29 @@ class RelationshipManager:
                    forgotten_points = []
            elif not isinstance(forgotten_points, list):
                forgotten_points = []

            # 计算当前时间
            current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")

            # 计算每个点的最终权重(原始权重 * 时间权重)
            weighted_points = []
            for point in current_points:
                time_weight = self.calculate_time_weight(point[2], current_time)
                final_weight = point[1] * time_weight
                weighted_points.append((point, final_weight))

            # 计算总权重
            total_weight = sum(w for _, w in weighted_points)

            # 按权重随机选择要保留的点
            remaining_points = []
            points_to_move = []

            # 对每个点进行随机选择
            for point, weight in weighted_points:
                # 计算保留概率(权重越高越可能保留)
                keep_probability = weight / total_weight

                if len(remaining_points) < 10:
                    # 如果还没达到30条,直接保留
                    remaining_points.append(point)
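Aside (not part of the diff): a sketch of the time-weighted forgetting loop above. The first 10 points are kept outright as in the hunk; the random draw against keep_probability for the rest is an assumption, since that branch falls between this hunk and the next.

import random

def forget_excess(points, calc_time_weight, now, keep=10):
    weighted = [(p, p[1] * calc_time_weight(p[2], now)) for p in points]
    total_weight = sum(w for _, w in weighted) or 1.0
    remaining, moved = [], []
    for point, weight in weighted:
        keep_probability = weight / total_weight
        if len(remaining) < keep or random.random() < keep_probability:
            remaining.append(point)   # stays in "points"
        else:
            moved.append(point)       # moved to "forgotten_points"
    return remaining, moved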
@@ -407,28 +397,26 @@ class RelationshipManager:
                else:
                    # 不保留这个点
                    points_to_move.append(point)

            # 更新points和forgotten_points
            current_points = remaining_points
            forgotten_points.extend(points_to_move)

            # 检查forgotten_points是否达到5条
            if len(forgotten_points) >= 10:
                # 构建压缩总结提示词
                alias_str = ", ".join(global_config.bot.alias_names)

                # 按时间排序forgotten_points
                forgotten_points.sort(key=lambda x: x[2])

                # 构建points文本
-                points_text = "\n".join([
-                    f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}"
-                    for point in forgotten_points
-                ])
+                points_text = "\n".join(
+                    [f"时间:{point[2]}\n权重:{point[1]}\n内容:{point[0]}" for point in forgotten_points]
+                )

                impression = await person_info_manager.get_value(person_id, "impression") or ""

                compress_prompt = f"""
你的名字是{global_config.bot.nickname},{global_config.bot.nickname}的别名是{alias_str}。
请不要混淆你自己和{global_config.bot.nickname}和{person_name}。
@@ -449,88 +437,85 @@ class RelationshipManager:
"""
                # 调用LLM生成压缩总结
                compressed_summary, _ = await self.relationship_llm.generate_response_async(prompt=compress_prompt)

                current_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
                compressed_summary = f"截至{current_time},你对{person_name}的了解:{compressed_summary}"

                await person_info_manager.update_one_field(person_id, "impression", compressed_summary)

                forgotten_points = []

            # 这句代码的作用是:将更新后的 forgotten_points(遗忘的记忆点)列表,序列化为 JSON 字符串后,写回到数据库中的 forgotten_points 字段
-            await person_info_manager.update_one_field(person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None))
+            await person_info_manager.update_one_field(
+                person_id, "forgotten_points", json.dumps(forgotten_points, ensure_ascii=False, indent=None)
+            )

        # 更新数据库
-        await person_info_manager.update_one_field(person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None))
+        await person_info_manager.update_one_field(
+            person_id, "points", json.dumps(current_points, ensure_ascii=False, indent=None)
+        )
        know_times = await person_info_manager.get_value(person_id, "know_times") or 0
        await person_info_manager.update_one_field(person_id, "know_times", know_times + 1)
        await person_info_manager.update_one_field(person_id, "last_know", timestamp)

        logger.info(f"印象更新完成 for {person_name}")

    def build_focus_readable_messages(self, messages: list, target_person_id: str = None) -> str:
-        """格式化消息,只保留目标用户和bot消息附近的内容"""
-        # 找到目标用户和bot的消息索引
-        target_indices = []
-        for i, msg in enumerate(messages):
-            user_id = msg.get("user_id")
-            platform = msg.get("chat_info_platform")
-            person_id = person_info_manager.get_person_id(platform, user_id)
-            if person_id == target_person_id:
-                target_indices.append(i)
-
-        if not target_indices:
-            return ""
-
-        # 获取需要保留的消息索引
-        keep_indices = set()
-        for idx in target_indices:
-            # 获取前后5条消息的索引
-            start_idx = max(0, idx - 5)
-            end_idx = min(len(messages), idx + 6)
-            keep_indices.update(range(start_idx, end_idx))
-
-        print(keep_indices)
-
-        # 将索引排序
-        keep_indices = sorted(list(keep_indices))
-
-        # 按顺序构建消息组
-        message_groups = []
-        current_group = []
-
-        for i in range(len(messages)):
-            if i in keep_indices:
-                current_group.append(messages[i])
-            elif current_group:
-                # 如果当前组不为空,且遇到不保留的消息,则结束当前组
-                if current_group:
-                    message_groups.append(current_group)
-                current_group = []
-
-        # 添加最后一组
-        if current_group:
-            message_groups.append(current_group)
-
-        # 构建最终的消息文本
-        result = []
-        for i, group in enumerate(message_groups):
-            if i > 0:
-                result.append("...")
-            group_text = build_readable_messages(
-                messages=group,
-                replace_bot_name=True,
-                timestamp_mode="normal_no_YMD",
-                truncate=False
-            )
-            result.append(group_text)
-
-        return "\n".join(result)
+        """格式化消息,只保留目标用户和bot消息附近的内容"""
+        # 找到目标用户和bot的消息索引
+        target_indices = []
+        for i, msg in enumerate(messages):
+            user_id = msg.get("user_id")
+            platform = msg.get("chat_info_platform")
+            person_id = person_info_manager.get_person_id(platform, user_id)
+            if person_id == target_person_id:
+                target_indices.append(i)
+
+        if not target_indices:
+            return ""
+
+        # 获取需要保留的消息索引
+        keep_indices = set()
+        for idx in target_indices:
+            # 获取前后5条消息的索引
+            start_idx = max(0, idx - 5)
+            end_idx = min(len(messages), idx + 6)
+            keep_indices.update(range(start_idx, end_idx))
+
+        print(keep_indices)
+
+        # 将索引排序
+        keep_indices = sorted(list(keep_indices))
+
+        # 按顺序构建消息组
+        message_groups = []
+        current_group = []
+
+        for i in range(len(messages)):
+            if i in keep_indices:
+                current_group.append(messages[i])
+            elif current_group:
+                # 如果当前组不为空,且遇到不保留的消息,则结束当前组
+                if current_group:
+                    message_groups.append(current_group)
+                current_group = []
+
+        # 添加最后一组
+        if current_group:
+            message_groups.append(current_group)
+
+        # 构建最终的消息文本
+        result = []
+        for i, group in enumerate(message_groups):
+            if i > 0:
+                result.append("...")
+            group_text = build_readable_messages(
+                messages=group, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=False
+            )
+            result.append(group_text)
+
+        return "\n".join(result)

    def calculate_time_weight(self, point_time: str, current_time: str) -> float:
        """计算基于时间的权重系数"""
        try:
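Aside (not part of the diff): the index arithmetic behind build_focus_readable_messages, reduced to plain lists so it can run standalone. The helper name and the radius parameter are illustrative; the ±5 window and the contiguous grouping mirror the method above.

def window_indices(target_indices, total, radius=5):
    """Keep a ±radius window around each target index, split into contiguous groups."""
    keep = set()
    for idx in target_indices:
        keep.update(range(max(0, idx - radius), min(total, idx + radius + 1)))
    groups, current = [], []
    for i in range(total):
        if i in keep:
            current.append(i)
        elif current:
            groups.append(current)
            current = []
    if current:
        groups.append(current)
    return groups

# window_indices([7, 30], 40) -> [[2, ..., 12], [25, ..., 35]]; the original
# renders each group with build_readable_messages and joins them with "...".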
@@ -538,7 +523,7 @@ class RelationshipManager:
            current_timestamp = datetime.strptime(current_time, "%Y-%m-%d %H:%M:%S")
            time_diff = current_timestamp - point_timestamp
            hours_diff = time_diff.total_seconds() / 3600

            if hours_diff <= 1:  # 1小时内
                return 1.0
            elif hours_diff <= 24:  # 1-24小时
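Aside (not part of the diff): the general shape of calculate_time_weight. Only the first two brackets (≤1 h → 1.0, then 1–24 h) are visible in this hunk; the return values past the first bracket below are placeholders, not the project's real numbers.

from datetime import datetime

def time_weight(point_time: str, current_time: str) -> float:
    fmt = "%Y-%m-%d %H:%M:%S"
    hours = (datetime.strptime(current_time, fmt) - datetime.strptime(point_time, fmt)).total_seconds() / 3600
    if hours <= 1:     # within the last hour: full weight
        return 1.0
    elif hours <= 24:  # placeholder value for 1-24 h
        return 0.8
    return 0.5         # placeholder value for older points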
@@ -564,18 +549,18 @@ class RelationshipManager:
            s1 = " ".join(str(x) for x in s1)
        if isinstance(s2, list):
            s2 = " ".join(str(x) for x in s2)

        # 转换为字符串类型
        s1 = str(s1)
        s2 = str(s2)

        # 1. 使用 jieba 进行分词
        s1_words = " ".join(jieba.cut(s1))
        s2_words = " ".join(jieba.cut(s2))

        # 2. 将两句话放入一个列表中
        corpus = [s1_words, s2_words]

        # 3. 创建 TF-IDF 向量化器并进行计算
        try:
            vectorizer = TfidfVectorizer()
@@ -586,7 +571,7 @@ class RelationshipManager:

            # 4. 计算余弦相似度
            similarity_matrix = cosine_similarity(tfidf_matrix)

            # 返回 s1 和 s2 的相似度
            return similarity_matrix[0, 1]
@@ -599,20 +584,20 @@ class RelationshipManager:
    def check_similarity(self, text1, text2, tfidf_threshold=0.5, seq_threshold=0.6):
        """
        使用两种方法检查文本相似度,只要其中一种方法达到阈值就认为是相似的。

        Args:
            text1: 第一个文本
            text2: 第二个文本
            tfidf_threshold: TF-IDF相似度阈值
            seq_threshold: SequenceMatcher相似度阈值

        Returns:
            bool: 如果任一方法达到阈值则返回True
        """
        # 计算两种相似度
        tfidf_sim = self.tfidf_similarity(text1, text2)
        seq_sim = self.sequence_similarity(text1, text2)

        # 只要其中一种方法达到阈值就认为是相似的
        return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold
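Aside (not part of the diff): check_similarity combines a TF-IDF cosine similarity over jieba tokens with a difflib SequenceMatcher ratio, and treats the texts as similar if either crosses its threshold. A self-contained sketch; treating the source's sequence_similarity helper as a SequenceMatcher ratio is an assumption.

from difflib import SequenceMatcher

import jieba
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def is_similar(text1: str, text2: str, tfidf_threshold=0.5, seq_threshold=0.6) -> bool:
    corpus = [" ".join(jieba.cut(text1)), " ".join(jieba.cut(text2))]
    try:
        tfidf_sim = cosine_similarity(TfidfVectorizer().fit_transform(corpus))[0, 1]
    except ValueError:  # e.g. empty vocabulary after tokenization
        tfidf_sim = 0.0
    seq_sim = SequenceMatcher(None, text1, text2).ratio()
    return tfidf_sim > tfidf_threshold or seq_sim > seq_threshold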