initial work to support f16

This commit is contained in:
Francis Lata
2025-03-07 16:33:03 -08:00
parent d5d5704625
commit 64ee757076
2 changed files with 13 additions and 9 deletions

View File

@@ -391,13 +391,15 @@ def train_retinanet():
return LambdaLR(optim, _lr_lambda)
@TinyJit
def _train_step(model, optim, lr_scheduler, x, **kwargs):
def _train_step(model, optim, lr_scheduler, loss_scaler, x, **kwargs):
optim.zero_grad()
losses = model(normalize(x, GPUS), **kwargs)
loss = sum([l for l in losses.values()])
loss = (sum([l for l in losses.values()]) * loss_scaler)
loss.backward()
for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
optim.step()
lr_scheduler.step()
@@ -410,7 +412,7 @@ def train_retinanet():
# ** hyperparameters **
config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
config["bs"] = BS = getenv("BS", 256)
config["bs"] = BS = getenv("BS", 12 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS)) # TODO: update float16 to use larger BS
config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
@@ -418,6 +420,8 @@ def train_retinanet():
config["lr"] = lr = getenv("LR", 0.0001 * (BS / 256))
config["lr_warmup_epochs"] = lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 1)
config["lr_warmup_factor"] = lr_warmup_factor = getenv("LR_WARMUP_FACTOR", 1e-3)
config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 256.0 if dtypes.default_float == dtypes.float16 else 1.0)
config["default_float"] = dtypes.default_float.name
if SEED: Tensor.manual_seed(SEED)
@@ -426,7 +430,8 @@ def train_retinanet():
# ** model setup **
backbone = resnet.ResNeXt50_32X4D(num_classes=None)
backbone.load_from_pretrained()
# TODO: Figure out if casting to float16 should be done
# backbone.load_from_pretrained()
_freeze_backbone_layers(backbone, 3)
model = RetinaNet(backbone, num_classes=NUM_CLASSES)
@@ -484,7 +489,7 @@ def train_retinanet():
GlobalCounters.reset()
x, y_bboxes, y_labels, matches, anchors, proc = proc
loss, losses = _train_step(model, optim, lr_scheduler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
loss, losses = _train_step(model, optim, lr_scheduler, loss_scaler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
pt = time.perf_counter()

View File

@@ -1,7 +1,6 @@
import math
from tinygrad import Tensor
from tinygrad import Tensor, dtypes
from tinygrad.helpers import flatten, get_child
import tinygrad.nn as nn
from examples.mlperf.helpers import generate_anchors, BoxCoder
from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
from examples.mlperf.losses import sigmoid_focal_loss, l1_loss
@@ -149,7 +148,7 @@ class ClassificationHead:
def _compute_loss(self, x:Tensor, labels:Tensor, matches:Tensor) -> Tensor:
labels = ((labels + 1) * (fg_idxs := matches >= 0) - 1).one_hot(num_classes=x.shape[-1])
valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
loss = valid_idxs.where(sigmoid_focal_loss(x, labels), 0).sum(-1).sum(-1)
loss = valid_idxs.where(sigmoid_focal_loss(x.cast(dtypes.float32), labels), 0).sum(-1).sum(-1)
loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
return loss
@@ -174,7 +173,7 @@ class RegressionHead:
def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
x = x * mask
x = x.cast(dtypes.float32) * mask
tgt = self.box_coder.encode(bboxes, anchors) * mask
loss = l1_loss(x, tgt).sum(-1).sum(-1)
loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]