fix issue with realized on dataloader

This commit is contained in:
Francis Lata
2025-01-31 08:31:25 -08:00
parent 744753f7a4
commit 80fa9dd731

View File

@@ -363,19 +363,19 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
img = image_load(base_dir, img["subset"], img["file_name"])
tgt = prepare_target(ann, img_id, img.size[::-1])
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
if val:
assert img_ids is not None, "img_ids have to be passed in"
img = resize(img)[0]
img_ids[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = np.array(img_id).tobytes()
img_ids[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = np.array(img_id).tobytes()
else:
assert boxes is not None and labels is not None and matches is not None and anchors is not None, "boxes, labels, matches and anchors have to be passed in"
if seed is not None:
np.random.seed(seed)
torch.manual_seed(seed)
img, tgt = random_horizontal_flip(img, tgt)
img, tgt, _ = resize(img, tgt=tgt)
match_quality_matrix = box_iou(tgt["boxes"], (anchor := np.concatenate(generate_anchors((800, 800)))))
@@ -383,12 +383,12 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
clipped_match_idxs = np.clip(match_idxs, 0, None)
clipped_boxes, clipped_labels = tgt["boxes"][clipped_match_idxs], tgt["labels"][clipped_match_idxs]
boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
boxes[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_boxes.tobytes()
labels[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = clipped_labels.tobytes()
matches[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
anchors[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = anchor.tobytes()
imgs[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
imgs[idx].contiguous().realize().lazydata.base.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
queue_out.put(idx)
queue_out.put(None)