cleanup train and validation dataloader

This commit is contained in:
Francis Lata
2025-01-31 16:59:37 -08:00
parent 6d70035c22
commit 811893a3bd

View File

@@ -351,8 +351,8 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
### RetinaNet
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue,
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Tensor,
anchors:Tensor, img_ids:Tensor, seed:Optional[int] = None):
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Optional[Tensor] = None,
anchors:Optional[Tensor] = None, seed:Optional[int] = None):
from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize
from examples.mlperf.helpers import box_iou, find_matches, generate_anchors
import torch
@@ -391,8 +391,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
def _enqueue_batch(bc):
for idx in range(bc * batch_size, (bc+1) * batch_size):
img = dataset.loadImgs(next(dataset_iter))[0]
ann = dataset.loadAnns(dataset.getAnnIds((img_id:=img["id"])))
img_ids[idx] = img_id
ann = dataset.loadAnns(dataset.getAnnIds(img["id"]))
queue_in.put((idx, img, ann))
def _setup_shared_mem(shm_name:str, size:Tuple[int, ...], dtype:dtypes) -> Tuple[shared_memory.SharedMemory, Tensor]:
@@ -408,12 +407,15 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
procs, data_out_count = [], [0] * batch_count
shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64)
img_ids = [None] * (batch_size * batch_count)
if val:
matches, anchors = None, None
boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
else:
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64)
shutdown = False
class Cookie:
@@ -435,8 +437,8 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
for _ in range(cpu_count()):
proc = Process(
target=load_retinanet_data,
args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels, matches, anchors, img_ids),
kwargs={"seed": seed}
args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels),
kwargs={"matches": matches, "anchors": anchors, "seed": seed}
)
proc.daemon = True
proc.start()
@@ -455,7 +457,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
if val:
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
img_ids[bc * batch_size:(bc + 1) * batch_size],
image_ids[bc * batch_size:(bc + 1) * batch_size],
Cookie(bc))
else:
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
@@ -478,17 +480,21 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
for proc in procs: proc.join()
shm_imgs.close()
shm_boxes.close()
shm_labels.close()
shm_matches.close()
shm_anchors.close()
if not val:
shm_boxes.close()
shm_labels.close()
shm_matches.close()
shm_anchors.close()
try:
shm_imgs.unlink()
shm_boxes.unlink()
shm_labels.unlink()
shm_matches.unlink()
shm_anchors.unlink()
if not val:
shm_boxes.unlink()
shm_labels.unlink()
shm_matches.unlink()
shm_anchors.unlink()
except FileNotFoundError:
# happens with BENCHMARK set
pass