mirror of https://github.com/tinygrad/tinygrad.git

initial work to support f16
@@ -391,13 +391,15 @@ def train_retinanet():
     return LambdaLR(optim, _lr_lambda)
 
   @TinyJit
-  def _train_step(model, optim, lr_scheduler, x, **kwargs):
+  def _train_step(model, optim, lr_scheduler, loss_scaler, x, **kwargs):
     optim.zero_grad()
 
     losses = model(normalize(x, GPUS), **kwargs)
-    loss = sum([l for l in losses.values()])
+    loss = (sum([l for l in losses.values()]) * loss_scaler)
 
     loss.backward()
+    for t in optim.params: t.grad = t.grad.contiguous() / loss_scaler
 
     optim.step()
     lr_scheduler.step()
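
This is static loss scaling, the standard fix for float16's narrow exponent range: small gradients underflow to zero in half precision, so the loss is multiplied by a constant before backward() and every gradient is divided by the same constant before optim.step(), restoring true magnitudes. A minimal self-contained sketch of the pattern (the toy model and the 256.0 value are illustrative, not from this commit):

  from tinygrad import Tensor, dtypes
  from tinygrad.nn import Linear
  from tinygrad.nn.optim import SGD
  from tinygrad.nn.state import get_parameters

  dtypes.default_float = dtypes.float16  # train in half precision
  LOSS_SCALER = 256.0                    # fixed scale; illustrative value

  model = Linear(16, 4)
  opt = SGD(get_parameters(model), lr=1e-3)

  Tensor.training = True
  x, y = Tensor.randn(8, 16), Tensor.randn(8, 4)
  opt.zero_grad()
  loss = (model(x) - y).square().mean() * LOSS_SCALER  # scale the loss up before backward
  loss.backward()
  for p in opt.params: p.grad = p.grad / LOSS_SCALER   # unscale gradients before the step
  opt.step()

The .contiguous() in the diff's unscale loop is presumably there to force the accumulated gradient into its own buffer before dividing; the division is the part that matters numerically.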
@@ -410,7 +412,7 @@ def train_retinanet():
 
   # ** hyperparameters **
   config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
-  config["bs"] = BS = getenv("BS", 256)
+  config["bs"] = BS = getenv("BS", 12 * len(GPUS) if dtypes.default_float == dtypes.float16 else 12 * len(GPUS))  # TODO: update float16 to use larger BS
   config["eval_bs"] = EVAL_BS = getenv("EVAL_BS", BS)
   config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
   config["train_beam"] = TRAIN_BEAM = getenv("TRAIN_BEAM", BEAM.value)
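
Both arms of the new BS conditional are currently 12 * len(GPUS); per the TODO, the float16 arm is meant to grow once the memory savings are exploited. All of these hyperparameters follow the same pattern: getenv from tinygrad.helpers reads an environment variable and coerces it to the type of the default, so a run can be reconfigured without editing the script. A small usage sketch (the invocation is illustrative):

  from tinygrad.helpers import getenv

  BS = getenv("BS", 12)    # e.g. BS=24 in the environment yields the int 24
  LR = getenv("LR", 1e-4)  # a float default parses the variable as float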
@@ -418,6 +420,8 @@ def train_retinanet():
   config["lr"] = lr = getenv("LR", 0.0001 * (BS / 256))
   config["lr_warmup_epochs"] = lr_warmup_epochs = getenv("LR_WARMUP_EPOCHS", 1)
   config["lr_warmup_factor"] = lr_warmup_factor = getenv("LR_WARMUP_FACTOR", 1e-3)
+  config["loss_scaler"] = loss_scaler = getenv("LOSS_SCALER", 256.0 if dtypes.default_float == dtypes.float16 else 1.0)
+  config["default_float"] = dtypes.default_float.name
 
   if SEED: Tensor.manual_seed(SEED)
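
With float32 the scaler defaults to 1.0, so both the multiply in _train_step and the later divide are no-ops; 256.0 is a conventional starting point for float16. The commit uses a static scale. A common refinement (not in this commit) is dynamic loss scaling, which shrinks the scale and skips the step when gradients overflow, and grows it while they stay finite. A sketch under those assumptions, with a hypothetical helper name:

  import math

  def dynamic_scale_step(opt, compute_loss, scaler: float) -> float:
    # hypothetical helper, not from the commit: returns the next scale to use
    opt.zero_grad()
    (compute_loss() * scaler).backward()
    grads = [p.grad / scaler for p in opt.params]
    if all(math.isfinite(g.abs().max().item()) for g in grads):
      for p, g in zip(opt.params, grads): p.grad = g
      opt.step()
      return min(scaler * 2.0, 2.0**16)  # healthy step: grow the scale
    return max(scaler / 2.0, 1.0)        # overflow: skip the step, back off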
@@ -426,7 +430,8 @@ def train_retinanet():
 
   # ** model setup **
   backbone = resnet.ResNeXt50_32X4D(num_classes=None)
-  backbone.load_from_pretrained()
+  # TODO: Figure out if casting to float16 should be done
+  # backbone.load_from_pretrained()
   _freeze_backbone_layers(backbone, 3)
 
   model = RetinaNet(backbone, num_classes=NUM_CLASSES)
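
The commented-out load_from_pretrained and its TODO flag the open question for half precision: pretrained checkpoints ship in float32, so the weights must be cast somewhere if the model is to hold float16 parameters. One possible answer (an assumption, not what the commit settles on) is to load the fp32 state dict, cast its values, and then load it into the model; the checkpoint path below is hypothetical:

  from tinygrad import dtypes
  from tinygrad.nn.state import safe_load, load_state_dict

  state = safe_load("resnext50_32x4d.safetensors")               # hypothetical checkpoint file
  state = {k: v.cast(dtypes.float16) for k, v in state.items()}  # cast fp32 weights down
  load_state_dict(backbone, state)

Mixed-precision recipes also commonly keep batchnorm statistics in float32 even when weights are cast, which may be part of what the TODO is weighing.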
@@ -484,7 +489,7 @@ def train_retinanet():
       GlobalCounters.reset()
 
       x, y_bboxes, y_labels, matches, anchors, proc = proc
-      loss, losses = _train_step(model, optim, lr_scheduler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
+      loss, losses = _train_step(model, optim, lr_scheduler, loss_scaler, x, labels=y_labels, matches=matches, anchors=anchors, bboxes=y_bboxes)
 
       pt = time.perf_counter()
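
One subtlety worth noting (an observation, not from the commit): _train_step is wrapped in @TinyJit, and the JIT specializes on its first calls, afterwards replaying the captured kernels with only the Tensor inputs swapped. A plain Python float like loss_scaler is therefore baked into the captured graph, which is fine here because the scale is constant for the whole run. A toy illustration of the capture behavior:

  from tinygrad import Tensor, TinyJit

  @TinyJit
  def scaled(x: Tensor, scale: float) -> Tensor:
    # after the JIT captures, later calls reuse the first `scale` value
    return (x.square().mean() * scale).realize()

The remaining hunks are in the RetinaNet model definition itself (note the hunk offsets reset to line 1, and the headers name ClassificationHead and RegressionHead).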
@@ -1,7 +1,6 @@
 import math
-from tinygrad import Tensor
+from tinygrad import Tensor, dtypes
 from tinygrad.helpers import flatten, get_child
 import tinygrad.nn as nn
 from examples.mlperf.helpers import generate_anchors, BoxCoder
 from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
 from examples.mlperf.losses import sigmoid_focal_loss, l1_loss
@@ -149,7 +148,7 @@ class ClassificationHead:
   def _compute_loss(self, x:Tensor, labels:Tensor, matches:Tensor) -> Tensor:
     labels = ((labels + 1) * (fg_idxs := matches >= 0) - 1).one_hot(num_classes=x.shape[-1])
     valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
-    loss = valid_idxs.where(sigmoid_focal_loss(x, labels), 0).sum(-1).sum(-1)
+    loss = valid_idxs.where(sigmoid_focal_loss(x.cast(dtypes.float32), labels), 0).sum(-1).sum(-1)
     loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
     return loss
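
Casting the logits to float32 before the focal loss keeps its exponentials in range: float16 tops out at 65504, so e^x overflows for x beyond roughly 11, and the (1 - p_t)^gamma modulating factor underflows for confidently classified anchors. For reference, sigmoid focal loss (Lin et al., 2017) in its numerically stable form looks like the sketch below; the actual examples.mlperf.losses implementation may differ in details:

  from tinygrad import Tensor

  def sigmoid_focal_loss(logits: Tensor, targets: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
    p = logits.sigmoid()
    # numerically stable elementwise binary cross-entropy computed from logits
    ce = logits.maximum(0) - logits * targets + (1 + logits.abs().neg().exp()).log()
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balance weight
    return alpha_t * ce * (1 - p_t) ** gamma                 # down-weight easy examples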
@@ -174,7 +173,7 @@ class RegressionHead:
 
   def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
     mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
-    x = x * mask
+    x = x.cast(dtypes.float32) * mask
     tgt = self.box_coder.encode(bboxes, anchors) * mask
     loss = l1_loss(x, tgt).sum(-1).sum(-1)
     loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
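
The regression head gets the matching change: its raw float16 outputs are cast up before masking so the L1 loss and the anchor-wise sums accumulate in float32. Given how the caller reduces with .sum(-1).sum(-1), l1_loss is presumably the plain elementwise form, roughly the sketch below (an assumption; the mlperf helper may differ):

  from tinygrad import Tensor

  def l1_loss(pred: Tensor, target: Tensor) -> Tensor:
    # elementwise absolute error; callers apply their own .sum() reductions
    return (pred - target).abs()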