start some work on classification loss

This commit is contained in:
Francis Lata
2024-12-10 16:33:16 +00:00
parent c080fcdaab
commit e5bc0c0485
3 changed files with 40 additions and 13 deletions

View File

@@ -1,3 +1,5 @@
from typing import Optional
from examples.mlperf.metrics import dice_score
from tinygrad import Tensor
@@ -6,9 +8,10 @@ def dice_ce_loss(pred, tgt):
dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
return (dice + ce) / 2
def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor]=None, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
  """Sigmoid focal loss (Lin et al., "Focal Loss for Dense Object Detection").

  Args:
    pred: raw (pre-sigmoid) logits.
    tgt: binary targets, same shape as `pred`.
    mask: optional multiplicative mask (e.g. valid-anchor indicator); applied to
      the predicted probabilities and to the per-element loss.
    alpha: balance weight for the positive class.
    gamma: focusing exponent that down-weights easy examples.
    reduction: one of "none", "mean", "sum".

  Returns:
    The per-element loss ("none") or its mean/sum.
  """
  assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
  p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
  # `if mask:` would evaluate Tensor truthiness (and wrongly skips falsy masks);
  # compare against None explicitly since mask is Optional.
  if mask is not None: p = p * mask
  p_t = p * tgt + (1 - p) * (1 - tgt)
  loss = ce_loss * ((1 - p_t) ** gamma)
  alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
  loss *= alpha_t
  if mask is not None: loss = loss * mask
  if reduction == "mean": loss = loss.mean()
  elif reduction == "sum": loss = loss.sum()
  return loss

View File

@@ -370,7 +370,7 @@ def train_retinanet():
def _data_get(it):
  """Pull the next batch from `it` and shard device tensors across GPUS.

  `y_boxes` stays on host; `x`, `y_labels` and `matches` are sharded along the
  batch axis so each GPU receives its slice. The diff residue's duplicate
  (un-sharded) return line is dropped.
  """
  x, y_boxes, y_labels, matches, cookie = next(it)
  return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), cookie
def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
# TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
@@ -381,10 +381,11 @@ def train_retinanet():
return warmup_factor * (1 - alpha) + alpha
return LambdaLR(optim, _lr_lambda)
def _train_step(model, optim, lr_scheduler, x, matches):
@Tensor.train()
def _train_step(model, optim, lr_scheduler, x, y, matches):
optim.zero_grad()
y_hat = model(normalize(x, GPUS))
y_hat = model(normalize(x, GPUS), y=y, matches=matches)
# ** hyperparameters **
# using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3
@@ -435,7 +436,7 @@ def train_retinanet():
while proc is not None:
x, y_boxes, y_labels, matches, proc = proc
_train_step(model, optim, lr_scheduler, x, matches) # TODO: enable once full model has been integrated
_train_step(model, optim, lr_scheduler, x, y_labels, matches) # TODO: enable once full model has been integrated
if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
try:

View File

@@ -1,8 +1,11 @@
from typing import Optional
import math
from tinygrad import Tensor, dtypes
from tinygrad.helpers import flatten, get_child
import tinygrad.nn as nn
from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
from examples.mlperf.losses import sigmoid_focal_loss
from extra.models.helpers import meshgrid, nms
from extra.models.resnet import ResNet
import numpy as np
@@ -47,10 +50,11 @@ class RetinaNet:
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
  """Make the model callable; delegates to `forward`."""
  return self.forward(x, y=y, matches=matches)

def forward(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
  """Run the backbone then the detection head.

  `y` and `matches` are forwarded to the head for training-time loss
  computation; at inference they stay None. (Old pre-diff definitions removed.)
  """
  return self.head(self.backbone(x), y=y, matches=matches)
def load_from_pretrained(self):
model_urls = {
@@ -136,9 +140,26 @@ class ClassificationHead:
self.num_classes = num_classes
self.conv = flatten([(Conv2dNormal(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.cls_logits = Conv2dNormal(in_channels, num_anchors * num_classes, kernel_size=3, padding=1, prior_prob=prior_prob)
def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
  """Per-anchor class scores at eval time; classification loss at train time.

  Args:
    x: list of FPN feature maps.
    y: target labels, required when Tensor.training.
    matches: anchor match encoding, required when Tensor.training.

  Returns:
    Sigmoid scores of shape (batch, anchors, num_classes) in eval mode,
    otherwise the focal classification loss from `_compute_loss`.
  """
  out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
  out = out[0].cat(*out[1:], dim=1).sigmoid()
  # guard clause: inference path returns the raw scores directly
  if not Tensor.training: return out
  assert y is not None and matches is not None, "y and matches should be passed in when training"
  return self._compute_loss(out, y, matches)
def _compute_loss(self, x:Tensor, y:Tensor, matches:Tensor) -> Tensor:
  """Sigmoid focal loss for the classification head.

  Args:
    x: per-anchor class scores, shape (batch, anchors, num_classes).
    y: per-anchor target labels.
    matches: anchor match encoding; -2 appears to mark anchors to ignore — TODO confirm.

  Returns:
    The summed focal loss over valid anchors.
  """
  # build one-hot targets from the matched labels
  y = ((y + 1) * matches - 1).one_hot(num_classes=x.shape[-1])
  # find indices for which anchors should be ignored
  valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
  # NOTE(review): removed leftover `import pdb; pdb.set_trace()` breakpoint, and
  # return the loss — the original computed it and fell off the end (None).
  return sigmoid_focal_loss(x, y, mask=valid_idxs, reduction="sum")
class RegressionHead:
def __init__(self, in_channels, num_anchors):
@@ -152,8 +173,9 @@ class RetinaHead:
def __init__(self, in_channels, num_anchors, num_classes):
  """Bundle the box-regression and classification sub-heads of RetinaNet."""
  self.regression_head = RegressionHead(in_channels, num_anchors)
  self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
  """Run both heads over the feature maps and concatenate their outputs.

  (Old pre-diff definition removed.)
  NOTE(review): when Tensor.training, the classification head returns a loss
  scalar rather than per-anchor scores, so cat-ing it with pred_bbox looks
  wrong for the training path — confirm intended behavior.
  """
  pred_bbox = self.regression_head(x)
  pred_class = self.classification_head(x, y=y, matches=matches)
  return pred_bbox.cat(pred_class, dim=-1)