mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cleanup train and validation dataloader
This commit is contained in:
@@ -351,8 +351,8 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
|
||||
### RetinaNet
|
||||
|
||||
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue,
|
||||
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Tensor,
|
||||
anchors:Tensor, img_ids:Tensor, seed:Optional[int] = None):
|
||||
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Optional[Tensor] = None,
|
||||
anchors:Optional[Tensor] = None, seed:Optional[int] = None):
|
||||
from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize
|
||||
from examples.mlperf.helpers import box_iou, find_matches, generate_anchors
|
||||
import torch
|
||||
@@ -391,8 +391,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
def _enqueue_batch(bc):
|
||||
for idx in range(bc * batch_size, (bc+1) * batch_size):
|
||||
img = dataset.loadImgs(next(dataset_iter))[0]
|
||||
ann = dataset.loadAnns(dataset.getAnnIds((img_id:=img["id"])))
|
||||
img_ids[idx] = img_id
|
||||
ann = dataset.loadAnns(dataset.getAnnIds(img["id"]))
|
||||
queue_in.put((idx, img, ann))
|
||||
|
||||
def _setup_shared_mem(shm_name:str, size:Tuple[int, ...], dtype:dtypes) -> Tuple[shared_memory.SharedMemory, Tensor]:
|
||||
@@ -408,12 +407,15 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
procs, data_out_count = [], [0] * batch_count
|
||||
|
||||
shm_imgs, imgs = _setup_shared_mem("retinanet_imgs", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
|
||||
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64)
|
||||
|
||||
img_ids = [None] * (batch_size * batch_count)
|
||||
if val:
|
||||
matches, anchors = None, None
|
||||
boxes, labels = [None] * (batch_size * batch_count), [None] * (batch_size * batch_count)
|
||||
else:
|
||||
shm_boxes, boxes = _setup_shared_mem("retinanet_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_labels, labels = _setup_shared_mem("retinanet_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_anchors, anchors = _setup_shared_mem("retinanet_anchors", (batch_size * batch_count, 120087, 4), dtypes.float64)
|
||||
|
||||
shutdown = False
|
||||
class Cookie:
|
||||
@@ -435,8 +437,8 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
for _ in range(cpu_count()):
|
||||
proc = Process(
|
||||
target=load_retinanet_data,
|
||||
args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels, matches, anchors, img_ids),
|
||||
kwargs={"seed": seed}
|
||||
args=(base_dir, val, queue_in, queue_out, imgs, boxes, labels),
|
||||
kwargs={"matches": matches, "anchors": anchors, "seed": seed}
|
||||
)
|
||||
proc.daemon = True
|
||||
proc.start()
|
||||
@@ -455,7 +457,7 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
|
||||
if val:
|
||||
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
|
||||
img_ids[bc * batch_size:(bc + 1) * batch_size],
|
||||
image_ids[bc * batch_size:(bc + 1) * batch_size],
|
||||
Cookie(bc))
|
||||
else:
|
||||
yield (imgs[bc * batch_size:(bc + 1) * batch_size],
|
||||
@@ -478,17 +480,21 @@ def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, sh
|
||||
for proc in procs: proc.join()
|
||||
|
||||
shm_imgs.close()
|
||||
shm_boxes.close()
|
||||
shm_labels.close()
|
||||
shm_matches.close()
|
||||
shm_anchors.close()
|
||||
|
||||
if not val:
|
||||
shm_boxes.close()
|
||||
shm_labels.close()
|
||||
shm_matches.close()
|
||||
shm_anchors.close()
|
||||
|
||||
try:
|
||||
shm_imgs.unlink()
|
||||
shm_boxes.unlink()
|
||||
shm_labels.unlink()
|
||||
shm_matches.unlink()
|
||||
shm_anchors.unlink()
|
||||
|
||||
if not val:
|
||||
shm_boxes.unlink()
|
||||
shm_labels.unlink()
|
||||
shm_matches.unlink()
|
||||
shm_anchors.unlink()
|
||||
except FileNotFoundError:
|
||||
# happens with BENCHMARK set
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user