RetinaNet MLPerf (#8385)

* add support for a custom BASEDIR for openimages download

* make export step faster

* add focal loss
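
For context: RetinaNet's classification loss is the sigmoid focal loss of Lin et al. (2017), FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). A minimal tinygrad-style sketch of the math (written in the plain rather than the numerically-stabilized logits form, so not necessarily the exact code added here):

```python
from tinygrad import Tensor

def sigmoid_focal_loss(logits:Tensor, targets:Tensor, alpha:float=0.25, gamma:float=2.0) -> Tensor:
  p = logits.sigmoid()
  ce = -(targets * p.log() + (1 - targets) * (1 - p).log())  # elementwise BCE
  p_t = p * targets + (1 - p) * (1 - targets)                # probability of the true class
  loss = ce * (1 - p_t) ** gamma                             # down-weight easy examples
  if alpha >= 0: loss = loss * (alpha * targets + (1 - alpha) * (1 - targets))
  return loss
```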

* update model_eval with new dataloader

* generate_anchors in tinygrad
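
The anchor generation follows the usual torchvision recipe: per FPN level, build one base anchor per (scale, aspect ratio) pair, then shift it across the feature-map grid. A numpy sketch of the base-anchor step only (hypothetical helper name; a later entry in this log reverts anchors to numpy anyway):

```python
import numpy as np

def cell_anchors(scales, aspect_ratios) -> np.ndarray:
  # base anchors centered at the origin, as (x1, y1, x2, y2);
  # for ratio r and scale s: w = s/sqrt(r), h = s*sqrt(r), so w*h == s**2
  ws = np.array([s / np.sqrt(r) for r in aspect_ratios for s in scales])
  hs = np.array([s * np.sqrt(r) for r in aspect_ratios for s in scales])
  return np.stack([-ws, -hs, ws, hs], axis=1) / 2
```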

* update initializers for model

* small cleanup

* revert isin enhancements

* recursively go through backbone layers to freeze them

* add optimizer

* minor cleanup

* start dataloader work with input images

* add first transform for train set

* reuse existing prepare_target

* continue with dataloader implementation

* add dataloader

* separate out KiTS19 dataset test cases

* create mock data samples for test

* add dataloader + test

* cleanup dataloader test and revert shm path

* trim dataloader related code needed from ref

* got dataloader with normalize working

* update image to be float32

* add back normalization and negate it in test

* clean up reference dataset implementation + ruff changes

* add validation set test

* add proper training loop over the training dataset

* add LambdaLR support
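
A LambdaLR-style scheduler scales the optimizer's initial LR by a user-supplied function of the epoch counter (note this scheduler is removed again near the end of this log). A rough sketch of the idea, not necessarily the class that was added:

```python
class LambdaLR:
  def __init__(self, optim, lr_lambda):
    # optim.lr is a Tensor in tinygrad, so keep a copy of the initial value
    self.optim, self.lr_lambda = optim, lr_lambda
    self.base_lr, self.epoch = optim.lr.clone(), 0
  def step(self) -> None:
    self.epoch += 1
    self.optim.lr.assign(self.base_lr * self.lr_lambda(self.epoch)).realize()
```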

* add LR scheduler and the start of training step

* get the forward call to the model working and set up multi-GPU

* already passed device

* return matches from dataloader
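
The matches follow the reference Matcher: each anchor takes the ground-truth box it overlaps most, with IoU >= 0.5 as foreground, < 0.4 as background, and the band in between ignored. A numpy sketch of that assignment (omitting the low-quality-match backfill the full Matcher also performs):

```python
import numpy as np

def match_anchors(iou:np.ndarray, high:float=0.5, low:float=0.4) -> np.ndarray:
  # iou: (num_gt, num_anchors); per anchor, return the index of its best gt box,
  # or -1 for background and -2 for the ignored IoU band
  matches, best = iou.argmax(axis=0), iou.max(axis=0)
  matches[best < low] = -1
  matches[(best >= low) & (best < high)] = -2
  return matches
```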

* hotfix for dataloader typo causing some hang

* start some work on classification loss

* update focal loss to support masking

* add missing test and cleanup focal loss

* cleanup unit tests

* remove masking support for sigmoid_focal_loss

* make ClassificationHead loss work
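
The head's loss applies focal loss to one-hot class targets over every anchor outside the ignored band, normalized by the number of foreground anchors. A hedged sketch reusing the sigmoid_focal_loss and the -2 ignore code from the sketches above:

```python
from tinygrad import Tensor

def classification_loss(logits:Tensor, tgt_onehot:Tensor, matches:Tensor) -> Tensor:
  # logits/tgt_onehot: (batch, anchors, classes); matches: (batch, anchors)
  valid = (matches != -2).unsqueeze(-1)      # drop anchors in the ignore band
  num_fg = (matches >= 0).sum().maximum(1)   # clamp to avoid division by zero
  return (sigmoid_focal_loss(logits, tgt_onehot) * valid).sum() / num_fg
```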

* cleanups + fix dataloader tests

* remove sigmoid when computing loss

* make anchors use Tensors

* simplify anchors batching

* revert anchors to use np

* implement regression loss
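
The regression loss is an L1 between predicted and target box deltas, averaged over foreground anchors only (a later entry notes the mask is applied after the L1 is computed). A hedged sketch:

```python
from tinygrad import Tensor

def regression_loss(pred_deltas:Tensor, tgt_deltas:Tensor, matches:Tensor) -> Tensor:
  # pred/tgt deltas: (batch, anchors, 4); matches: (batch, anchors)
  per_anchor = (pred_deltas - tgt_deltas).abs().sum(-1)  # L1 first...
  fg = (matches >= 0)
  return (per_anchor * fg).sum() / fg.sum().maximum(1)   # ...then mask and average
```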

* fix regression loss

* cleanup losses

* move BoxCoder to MLPerf helpers
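
BoxCoder turns a matched ground-truth box into (dx, dy, dw, dh) regression targets relative to its anchor, the classic Faster R-CNN encoding (RetinaNet uses unit weights). A numpy sketch of the encode direction:

```python
import numpy as np

def encode_boxes(gt:np.ndarray, anchors:np.ndarray) -> np.ndarray:
  # gt/anchors: (N, 4) in (x1, y1, x2, y2) -> (N, 4) regression targets
  aw, ah = anchors[:, 2] - anchors[:, 0], anchors[:, 3] - anchors[:, 1]
  ax, ay = anchors[:, 0] + 0.5 * aw, anchors[:, 1] + 0.5 * ah
  gw, gh = gt[:, 2] - gt[:, 0], gt[:, 3] - gt[:, 1]
  gx, gy = gt[:, 0] + 0.5 * gw, gt[:, 1] + 0.5 * gh
  return np.stack([(gx - ax) / aw, (gy - ay) / ah, np.log(gw / aw), np.log(gh / ah)], axis=1)
```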

* revert helper changes

* fixes after helper refactor cleanup

* add tests for l1_loss

* start re-enabling training step

* minor cleanup

* add pycocotools to testing dependencies

* make training work

* adjust regression loss to mask after L1 loss is calculated

* reduce img and lbl sizes by half for KiTS19 dataset tests

* Revert "reduce img and lbl sizes by half for KiTS19 dataset tests"

This reverts commit d115b0c664.

* temporarily disable openimages dataset tests to debug CI

* enable openimages dataset test and create samples once

* temporarily disable openimages validation set test

* reenable test and add some debugging to the test

* add boto3 testing dependencies

* add pandas to testing dependencies

* This reverts commit 467704fec6.

* reenable test

* move sample creation to setup

* realize boxcoder's encoding

* add wandb

* fix wandb resuming feature

* move anchors as part of dataloader

* fix dtype for anchor inside dataloader and fix horizontal flip transformation

* add support for BENCHMARK

* set seed

* debug dataset test failure

* Revert "debug dataset test failure"

This reverts commit 1b2f9d7f50.

* fix dataloader script

* do not realize when sharding model weights

* setup openimages samples differently

* create the necessary samples per test case

* enable lr scheduler and fix benchmark timing

* add jit to the training loop

* add checkpointing and training resume capabilities
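
Checkpointing follows the pattern train_resnet already uses in this file: snapshot model, optimizer, and scheduler state together with safe_save. A sketch under that assumption (the helper names here wrap the existing calls; the path argument is hypothetical):

```python
from tinygrad.nn.state import safe_save, safe_load
from examples.mlperf.helpers import get_training_state, load_training_state

def save_checkpoint(model, optim, scheduler, path:str) -> None:
  safe_save(get_training_state(model, optim, scheduler), path)

def resume_checkpoint(model, optim, scheduler, path:str) -> None:
  load_training_state(model, optim, scheduler, safe_load(path))
```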

* refactor training loop and start work on the val loop

* add debug logging for dataloader test

* debug test

* assert boxes again

* update validation dataloader and more cleanups

* fix validation test case

* add multi device support to retinanet eval

* fix issue with realized on dataloader

* remove optional disk tensors in dataloader

* remove verbose debugging on datasets test

* put back parallel testing and remove img_ids Tensor from dataloader

* cleanup train and validation dataloader

* return validation targets in dataloader

* cleanup boxes and labels in dataloader

* fix img_ids repeating their values

* remove unnecessary targets from validation dataloader

* add validation loop to training script

* adjust LR to be the ratio of the batch size
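
This is the linear-scaling rule: keep the LR-to-batch-size ratio fixed so the effective per-example step stays constant. The final script pins the base LR at a reference batch size of 96:

```python
# as in the diff below: getenv("LR", 8.5e-5 * (BS / 96))
def scaled_lr(bs:int, base_lr:float=8.5e-5, ref_bs:int=96) -> float:
  return base_lr * (bs / ref_bs)
```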

* minor cleanups

* remove frozen layers from optimizer's params

* hyperparameter adjustments and cleanups

* model init, hyperparam, and data preprocessing updates

* no need to return loaded keys for resnet

* fix train script

* update loss calculation for RegressionHead and some cleanups

* add JIT reset support

* add nan check during training

* Revert "add nan check during training"

This reverts commit ddf1f0d5dd.

* Revert "Revert "add nan check during training""

This reverts commit b7b2943197.

* some typing cleanups

* update seeding on dataloader and the start of training script

* undo changes

* undo more changes

* more typing fixes

* minor cleanups

* update dataloader seed

* hotfix: log metric and move target metric check outside of CKPT

* check for CKPT when target metric is reached before saving

* add TRAIN_BEAM and EVAL_BEAM

* minor cleanup

* update hyperparams and add support for EVAL_BS

* add green coloring to metric reached statement

* initial work to support f16

* update model initializers to be monkeypatched

* update layers to support float32 weight loading + float16 training

* don't return loss that's scaled

* run eval on benchmark beam

* move BEAM to their respective steps

* update layers to be compatible with fp16

* end BENCHMARK after first eval

* cleanups and adjust learning rate for fp16

* remove duplicated files from test

* revert losses changes

* Revert "revert losses changes"

This reverts commit aebccf93ac.

* go back to old LR

* cast batchnorm to float32

* set new loss scaler default value for float16
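
Static loss scaling for fp16: scale the loss up before backward so small gradients don't flush to zero in half precision, then unscale the gradients before stepping. This mirrors _train_step in the diff below, where the default scaler becomes 2**11 for float16 and 1.0 otherwise:

```python
from tinygrad import Tensor

def scaled_backward_step(loss:Tensor, optim, loss_scaler:float) -> None:
  (loss * loss_scaler).backward()
  for t in optim.params: t.grad = t.grad / loss_scaler
  optim.step()
```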

* remove LambdaLRScheduler

* remove runner and use dataloader on eval

* fix retinanet eval with new dataloader

* remove unused import

* revert lr_scheduler updates

* use BS=96 with new learning rate

* rename module initializers

* more cleanups on training loop

* remove contig from optim.step

* simplify sum when computing loss

Author: Francis Lata
Date: 2025-04-12 22:11:51 -04:00 (committed by GitHub)
Commit: 2793cca9a6 (parent 23b67f532c)
2 changed files with 298 additions and 4 deletions

examples/mlperf/initializers.py:

@@ -68,3 +68,62 @@ class LayerNormBert:
     xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
     if not self.elementwise_affine: return xn
     return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))
+
+class FrozenBatchNorm2dRetinaNet(nn.BatchNorm2d):
+  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
+    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
+    self.weight = Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False) if affine else None
+    self.bias = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False) if affine else None
+    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False), Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False)
+    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.long, requires_grad=False)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    batch_mean, batch_var = super().calc_stats(x.cast(dtypes.float32))
+    if self.track_running_stats and Tensor.training:
+      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach().cast(self.running_mean.dtype))
+      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach().cast(self.running_var.dtype))
+      self.num_batches_tracked += 1
+    return x.cast(dtypes.float32).batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt()).cast(x.dtype)
+
+class Conv2dNormalRetinaNet(nn.Conv2d):
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
+               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
+               bias:bool=True, prior_prob:float|None=None):
+    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+    self.weight = Tensor.normal(*self.weight.shape, std=0.01, dtype=dtypes.float32)
+    if bias:
+      if prior_prob:
+        prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=dtypes.float32).expand(*self.bias.shape)
+        self.bias = -(((1 - prior_prob) / prior_prob).log())
+      else: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
+                    groups=self.groups, stride=self.stride, padding=self.padding)
+
+class Conv2dKaimingUniformRetinaNet(nn.Conv2d):
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
+               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
+               bias:bool=True):
+    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+    self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1, dtype=dtypes.float32)
+    if bias: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
+                    groups=self.groups, stride=self.stride, padding=self.padding)
+
+class Conv2dRetinaNet(nn.Conv2d):
+  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
+               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
+               bias:bool=True):
+    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
+    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
+    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale, dtype=dtypes.float32)
+    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale, dtype=dtypes.float32) if bias else None
+
+  def __call__(self, x:Tensor) -> Tensor:
+    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
+                    groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)

examples/mlperf/model_train.py:

@@ -1,11 +1,11 @@
-import os, time, math, functools
+import os, time, math, functools, random
 from pathlib import Path
 import multiprocessing

 from tinygrad import Device, GlobalCounters, Tensor, TinyJit, dtypes
 from tinygrad.helpers import getenv, BEAM, WINO, round_up, diskcache_clear, FUSE_CONV_BW
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_load, safe_save
-from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup
+from tinygrad.nn.optim import LAMB, LARS, SGD, OptimizerGroup, Adam

 from extra.lr_scheduler import LRSchedulerGroup
 from examples.mlperf.helpers import get_training_state, load_training_state
@@ -343,8 +343,243 @@ def train_resnet():
       safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)

 def train_retinanet():
-  # TODO: Retinanet
-  pass
+  from contextlib import redirect_stdout
+  from examples.mlperf.dataloader import batch_load_retinanet
+  from examples.mlperf.initializers import FrozenBatchNorm2dRetinaNet, Conv2dNormalRetinaNet, Conv2dKaimingUniformRetinaNet, Linear, Conv2dRetinaNet
+  from extra.datasets.openimages import MLPERF_CLASSES, BASEDIR, download_dataset, normalize
+  from extra.models import resnet
+  from pycocotools.coco import COCO
+  from pycocotools.cocoeval import COCOeval
+  from tinygrad.helpers import colored, Context
+  from typing import Iterator
+  import extra.models.retinanet as retinanet
+  import numpy as np
+
+  config, target_metric = {}, 0.34
+  NUM_CLASSES = len(MLPERF_CLASSES)
+  BASE_DIR = getenv("BASE_DIR", BASEDIR)
+  BENCHMARK = getenv("BENCHMARK")
+
+  config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
+  for x in GPUS: Device[x]
+  print(f"training on {GPUS}")
+
+  def _freeze_backbone_layers(backbone:resnet.ResNet, trainable_layers:int):
+    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
+    for k, v in get_state_dict(backbone).items():
+      if all([not k.startswith(layer) for layer in layers_to_train]):
+        v.requires_grad = False
+
+  def _data_get(it:Iterator[tuple[Tensor, ...]], val:bool=False):
+    if val:
+      x, img_ids, img_sizes, cookie = next(it)
+      return x.shard(GPUS, axis=0).realize(), img_ids, img_sizes, cookie
+    x, y_boxes, y_labels, matches, anchors, cookie = next(it)
+    return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
+
+  @TinyJit
+  def _train_step(model, optim, loss_scaler, x, **kwargs):
+    with Context(BEAM=TRAIN_BEAM):
+      optim.zero_grad()
+      losses = model(normalize(x, GPUS), **kwargs)
+      loss = sum([l for l in losses.values()])
+      (loss * loss_scaler).backward()
+      for t in optim.params: t.grad = t.grad / loss_scaler
+      optim.step()
+      return loss.realize(), losses
+
+  @TinyJit
+  def _eval_step(model, x, **kwargs):
+    with Context(BEAM=EVAL_BEAM):
+      out = model(normalize(x, GPUS), **kwargs)
+      return out.realize()
+
+  # ** hyperparameters **
+  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
+  config["bs"] = BS = getenv("BS", 16 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))
+  config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
+  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
+  config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
+  config["eval_beam"] = EVAL_BEAM = getenv("EVAL_BEAM", BEAM.value)
+  config["lr"] = lr = getenv("LR", 8.5e-5 * (BS / 96))
+  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 2**11 if dtypes.default_float == dtypes.float16 else 1.0)
+  config["default_float"] = dtypes.default_float.name
+
+  if SEED: Tensor.manual_seed(SEED)
+
+  # ** model initializers **
+  resnet.BatchNorm = FrozenBatchNorm2dRetinaNet
+  resnet.Linear = Linear
+  resnet.Conv2d = Conv2dRetinaNet
+  retinanet.ConvHead = Conv2dNormalRetinaNet
+  retinanet.ConvClassificationHeadLogits = functools.partial(Conv2dNormalRetinaNet, prior_prob=0.01)
+  retinanet.ConvFPN = Conv2dKaimingUniformRetinaNet
+
+  # ** model setup **
+  backbone = resnet.ResNeXt50_32X4D(num_classes=None)
+  backbone.load_from_pretrained()
+  _freeze_backbone_layers(backbone, 3)
+
+  model = retinanet.RetinaNet(backbone, num_classes=NUM_CLASSES)
+  params = get_parameters(model)
+  for p in params: p.to_(GPUS)
+
+  step_times, start_epoch = [], 0
+
+  # ** optimizer **
+  optim = Adam(params, lr=lr)
+
+  # ** dataset **
+  train_dataset = COCO(download_dataset(BASE_DIR, "train"))
+  val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
+  coco_val = COCOeval(cocoGt=val_dataset, iouType="bbox")
+
+  # ** lr scheduler **
+  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
+  config["steps_in_val_epoch"] = steps_in_val_epoch = (round_up(len(val_dataset.imgs.keys()), BS) // BS)
+  start_iter = start_epoch * steps_in_train_epoch
+
+  # ** initialize wandb **
+  if (WANDB:=getenv("WANDB")):
+    import wandb
+    wandb.init(config=config, project="MLPerf-RetinaNet")
+
+  print(f"training with batch size {BS} for {EPOCHS} epochs")
+
+  for e in range(start_epoch, EPOCHS):
+    # ** training loop **
+    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
+    it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
+    i, proc = 0, _data_get(it)
+
+    prev_cookies = []
+    st = time.perf_counter()
+
+    while proc is not None:
+      GlobalCounters.reset()
+
+      x, y_bboxes, y_labels, matches, anchors, proc = proc
+      loss, losses = _train_step(model, optim, loss_scaler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
+
+      pt = time.perf_counter()
+
+      if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
+      try:
+        next_proc = _data_get(it)
+      except StopIteration:
+        next_proc = None
+
+      dt = time.perf_counter()
+
+      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+      loss = loss.item()
+
+      cl = time.perf_counter()
+      if BENCHMARK: step_times.append(cl - st)
+
+      if not math.isfinite(loss):
+        print("loss is nan")
+        return
+
+      tqdm.write(
+        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
+        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
+        f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
+      )
+
+      if WANDB:
+        wandb.log({"lr": optim.lr.numpy(), "train/loss": loss, "train/classification_loss": losses["classification_loss"].item(), "train/regression_loss": losses["regression_loss"].item(),
+                   "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
+                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": e + (i + 1) / steps_in_train_epoch})
+
+      st = cl
+      prev_cookies.append(proc)
+      proc, next_proc = next_proc, None  # return old cookie
+      i += 1
+
+      if i == BENCHMARK:
+        assert not math.isnan(loss)
+
+        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
+        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
+              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
+
+        # if we are doing beam search, run the first eval too
+        if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
+        return
+
+    # ** eval loop **
+    if getenv("RESET_STEP", 1): _train_step.reset()
+
+    with Tensor.train(mode=False), Tensor.test():
+      val_dataloader = batch_load_retinanet(val_dataset, (val:=True), Path(BASE_DIR), batch_size=EVAL_BS, shuffle=False, seed=SEED)
+      it = iter(tqdm(val_dataloader, total=steps_in_val_epoch))
+      i, proc = 0, _data_get(it, val=val)
+
+      eval_times, prev_cookies = [], []
+      val_img_ids, val_imgs, ncats, narea = [], [], len(coco_val.params.catIds), len(coco_val.params.areaRng)
+
+      while proc is not None:
+        GlobalCounters.reset()
+        st = time.time()
+
+        out, img_ids, img_sizes, proc = _eval_step(model, (x:=proc[0])).numpy(), proc[1], proc[2], proc[3]
+        out = model.postprocess_detections(out, input_size=x.shape[1:3], orig_image_sizes=img_sizes)
+        coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
+                        for i, prediction in enumerate(out) for box, score, label in zip(*prediction.values())]
+
+        with redirect_stdout(None):
+          coco_val.cocoDt = val_dataset.loadRes(coco_results)
+          coco_val.params.imgIds = img_ids
+          coco_val.evaluate()
+
+        val_img_ids.extend(img_ids)
+        val_imgs.append(np.array(coco_val.evalImgs).reshape(ncats, narea, len(img_ids)))
+
+        if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = []  # free previous cookies after gpu work has been enqueued
+        try:
+          next_proc = _data_get(it, val=val)
+        except StopIteration:
+          next_proc = None
+
+        prev_cookies.append(proc)
+        proc, next_proc = next_proc, None
+        i += 1
+
+        if i == BENCHMARK:
+          return
+
+        et = time.time()
+        eval_times.append(et - st)
+
+      if getenv("RESET_STEP", 1): _eval_step.reset()
+
+      total_fw_time = sum(eval_times) / len(eval_times)
+      coco_val.params.imgIds = val_img_ids
+      coco_val._paramsEval.imgIds = val_img_ids
+      coco_val.evalImgs = list(np.concatenate(val_imgs, -1).flatten())
+      coco_val.accumulate()
+      coco_val.summarize()
+
+      val_metric = coco_val.stats[0]
+
+      tqdm.write(f"eval time: {total_fw_time:.2f}, eval metric: {val_metric:.4f}")
+      if WANDB:
+        wandb.log({"eval/forward_time": total_fw_time, "eval/metric": val_metric, "epoch": e + 1})
+
+      if val_metric >= target_metric:
+        print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}", color="green"))
+        break
def train_unet3d():
"""