feat:每个消息段有10%概率丢弃
This commit is contained in:
@@ -2,6 +2,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import random
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -415,11 +416,28 @@ class RelationshipBuilder:
|
|||||||
|
|
||||||
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
|
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, any]]):
|
||||||
"""基于消息段更新用户印象"""
|
"""基于消息段更新用户印象"""
|
||||||
logger.debug(f"开始为 {person_id} 基于 {len(segments)} 个消息段更新印象")
|
original_segment_count = len(segments)
|
||||||
|
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
||||||
try:
|
try:
|
||||||
|
# 筛选要处理的消息段,每个消息段有10%的概率被丢弃
|
||||||
|
segments_to_process = [s for s in segments if random.random() >= 0.1]
|
||||||
|
|
||||||
|
# 如果所有消息段都被丢弃,但原来有消息段,则至少保留一个(最新的)
|
||||||
|
if not segments_to_process and segments:
|
||||||
|
segments.sort(key=lambda x: x["end_time"], reverse=True)
|
||||||
|
segments_to_process.append(segments[0])
|
||||||
|
logger.debug(f"随机丢弃了所有消息段,强制保留最新的一个以进行处理。")
|
||||||
|
|
||||||
|
dropped_count = original_segment_count - len(segments_to_process)
|
||||||
|
if dropped_count > 0:
|
||||||
|
logger.info(f"为 {person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
|
||||||
|
|
||||||
processed_messages = []
|
processed_messages = []
|
||||||
|
|
||||||
for i, segment in enumerate(segments):
|
# 对筛选后的消息段进行排序,确保时间顺序
|
||||||
|
segments_to_process.sort(key=lambda x: x["start_time"])
|
||||||
|
|
||||||
|
for segment in segments_to_process:
|
||||||
start_time = segment["start_time"]
|
start_time = segment["start_time"]
|
||||||
end_time = segment["end_time"]
|
end_time = segment["end_time"]
|
||||||
start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time))
|
start_date = time.strftime("%Y-%m-%d %H:%M", time.localtime(start_time))
|
||||||
@@ -427,12 +445,12 @@ class RelationshipBuilder:
|
|||||||
# 获取该段的消息(包含边界)
|
# 获取该段的消息(包含边界)
|
||||||
segment_messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
|
segment_messages = get_raw_msg_by_timestamp_with_chat_inclusive(self.chat_id, start_time, end_time)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"消息段 {i + 1}: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}"
|
f"消息段: {start_date} - {time.strftime('%Y-%m-%d %H:%M', time.localtime(end_time))}, 消息数: {len(segment_messages)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if segment_messages:
|
if segment_messages:
|
||||||
# 如果不是第一个消息段,在消息列表前添加间隔标识
|
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
|
||||||
if i > 0:
|
if processed_messages:
|
||||||
# 创建一个特殊的间隔消息
|
# 创建一个特殊的间隔消息
|
||||||
gap_message = {
|
gap_message = {
|
||||||
"time": start_time - 0.1, # 稍微早于段开始时间
|
"time": start_time - 0.1, # 稍微早于段开始时间
|
||||||
|
|||||||
Reference in New Issue
Block a user