diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py index 905330974c..e9df7f5ee2 100644 --- a/examples/mlperf/dataloader.py +++ b/examples/mlperf/dataloader.py @@ -380,7 +380,7 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue match_quality_matrix = box_iou(tgt["boxes"], anchors) match_idxs = find_matches(match_quality_matrix, allow_low_quality_matches=True) match_idxs = np.clip(match_idxs, 0, None) - boxes, labels = tgt["boxes"][matches], tgt["labels"][matches] + boxes, labels = tgt["boxes"][match_idxs], tgt["labels"][match_idxs] Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes() Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()