mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
return matches from dataloader
This commit is contained in:
@@ -356,7 +356,8 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
|
||||
|
||||
### RetinaNet
|
||||
|
||||
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor, Y_boxes:Tensor, Y_labels:Tensor, anchors:np.ndarray, seed:Optional[int]=None):
|
||||
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue, X:Tensor,
|
||||
Y_boxes:Tensor, Y_labels:Tensor, matches:Tensor, anchors:np.ndarray, seed:Optional[int]=None):
|
||||
from extra.datasets.openimages import image_load, prepare_target, random_horizontal_flip, resize
|
||||
from examples.mlperf.helpers import box_iou, find_matches
|
||||
import torch
|
||||
@@ -377,12 +378,13 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
|
||||
img, tgt = random_horizontal_flip(img, tgt)
|
||||
img, tgt, _ = resize(img, tgt=tgt)
|
||||
match_quality_matrix = box_iou(tgt["boxes"], anchors)
|
||||
matches = find_matches(match_quality_matrix, allow_low_quality_matches=True)
|
||||
matches = np.clip(matches, 0, None)
|
||||
match_idxs = find_matches(match_quality_matrix, allow_low_quality_matches=True)
|
||||
match_idxs = np.clip(match_idxs, 0, None)
|
||||
boxes, labels = tgt["boxes"][matches], tgt["labels"][matches]
|
||||
|
||||
Y_boxes[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = boxes.tobytes()
|
||||
Y_labels[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = labels.tobytes()
|
||||
matches[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = match_idxs.tobytes()
|
||||
|
||||
X[idx].contiguous().realize().lazydata.realized.as_buffer(force_zero_copy=True)[:] = img.tobytes()
|
||||
|
||||
@@ -410,6 +412,7 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
shm_x, X = _setup_shared_mem("retinanet_x", (batch_size * batch_count, 800, 800, 3), dtypes.float32)
|
||||
shm_y_boxes, Y_boxes = _setup_shared_mem("retinanet_y_boxes", (batch_size * batch_count, 120087, 4), dtypes.float32)
|
||||
shm_y_labels, Y_labels = _setup_shared_mem("retinanet_y_labels", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
shm_matches, matches = _setup_shared_mem("retinanet_matches", (batch_size * batch_count, 120087), dtypes.int64)
|
||||
|
||||
shutdown = False
|
||||
class Cookie:
|
||||
@@ -429,7 +432,7 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
|
||||
try:
|
||||
for _ in range(cpu_count()):
|
||||
proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, anchors, seed))
|
||||
proc = Process(target=load_retinanet_data, args=(base_dir, val, queue_in, queue_out, X, Y_boxes, Y_labels, matches, anchors, seed))
|
||||
proc.daemon = True
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
@@ -448,7 +451,11 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
if val:
|
||||
yield X[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
|
||||
else:
|
||||
yield X[bc * batch_size:(bc + 1) * batch_size], Y_boxes[bc * batch_size:(bc + 1) * batch_size], Y_labels[bc * batch_size:(bc + 1) * batch_size], Cookie(bc)
|
||||
yield (X[bc * batch_size:(bc + 1) * batch_size],
|
||||
Y_boxes[bc * batch_size:(bc + 1) * batch_size],
|
||||
Y_labels[bc * batch_size:(bc + 1) * batch_size],
|
||||
matches[bc * batch_size:(bc + 1) * batch_size],
|
||||
Cookie(bc))
|
||||
finally:
|
||||
shutdown = True
|
||||
|
||||
@@ -465,10 +472,12 @@ def batch_load_retinanet(dataset, val:bool, anchors:np.ndarray, base_dir:Path, b
|
||||
shm_x.close()
|
||||
shm_y_boxes.close()
|
||||
shm_y_labels.close()
|
||||
shm_matches.close()
|
||||
try:
|
||||
shm_x.unlink()
|
||||
shm_y_boxes.unlink()
|
||||
shm_y_labels.unlink()
|
||||
shm_matches.unlink()
|
||||
except FileNotFoundError:
|
||||
# happens with BENCHMARK set
|
||||
pass
|
||||
@@ -499,8 +508,8 @@ if __name__ == "__main__":
|
||||
dataset = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "validation" if val else "train"))
|
||||
anchors = np.ones((120087, 4))
|
||||
with tqdm(total=len(dataset.imgs.keys())) as pbar:
|
||||
for x, _, _, _ in batch_load_retinanet(dataset, val, anchors, base_dir):
|
||||
pbar.update(x.shape[0])
|
||||
for x in batch_load_retinanet(dataset, val, anchors, base_dir):
|
||||
pbar.update(x[0].shape[0])
|
||||
|
||||
load_fn_name = f"load_{getenv('MODEL', 'resnet')}"
|
||||
if load_fn_name in globals():
|
||||
|
||||
@@ -367,8 +367,8 @@ def train_retinanet():
|
||||
layer.requires_grad = False
|
||||
|
||||
def _data_get(it):
|
||||
x, y_boxes, y_labels, cookie = next(it)
|
||||
return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, cookie
|
||||
x, y_boxes, y_labels, matches, cookie = next(it)
|
||||
return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, matches.shard(GPUS, axis=0), cookie
|
||||
|
||||
def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
|
||||
# TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
|
||||
@@ -379,7 +379,7 @@ def train_retinanet():
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
return LambdaLR(optim, _lr_lambda)
|
||||
|
||||
def _train_step(model, optim, lr_scheduler, x):
|
||||
def _train_step(model, optim, lr_scheduler, x, matches):
|
||||
optim.zero_grad()
|
||||
|
||||
y_hat = model(normalize(x, GPUS))
|
||||
@@ -432,8 +432,8 @@ def train_retinanet():
|
||||
st = time.perf_counter()
|
||||
|
||||
while proc is not None:
|
||||
x, y_boxes, y_labels, proc = proc
|
||||
_train_step(model, optim, lr_scheduler, x) # TODO: enable once full model has been integrated
|
||||
x, y_boxes, y_labels, matches, proc = proc
|
||||
_train_step(model, optim, lr_scheduler, x, matches) # TODO: enable once full model has been integrated
|
||||
|
||||
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user