mirror of https://github.com/tinygrad/tinygrad.git

start some work on classification loss
examples/mlperf/losses.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 from examples.mlperf.metrics import dice_score
 from tinygrad import Tensor
 
@@ -6,9 +8,10 @@ def dice_ce_loss(pred, tgt):
   dice = (1.0 - dice_score(pred, tgt, argmax=False, to_one_hot_x=False)).mean()
   return (dice + ce) / 2
 
-def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float = 2, reduction:str = "none"):
+def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, mask:Optional[Tensor]=None, alpha:float = 0.25, gamma:float = 2, reduction:str = "none") -> Tensor:
   assert reduction in ["mean", "sum", "none"], f"unsupported reduction {reduction}"
   p, ce_loss = pred.sigmoid(), pred.binary_crossentropy_logits(tgt, reduction="none")
+  if mask is not None: p = p * mask
   p_t = p * tgt + (1 - p) * (1 - tgt)
   loss = ce_loss * ((1 - p_t) ** gamma)
@@ -16,6 +19,7 @@ def sigmoid_focal_loss(pred:Tensor, tgt:Tensor, alpha:float = 0.25, gamma:float
     alpha_t = alpha * tgt + (1 - alpha) * (1 - tgt)
     loss *= alpha_t
 
+  if mask is not None: loss = loss * mask
   if reduction == "mean": loss = loss.mean()
   elif reduction == "sum": loss = loss.sum()
   return loss
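For reference, this is the focal loss from Lin et al.'s RetinaNet paper, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), extended here with an optional validity mask so ignored anchors contribute nothing. A NumPy sketch of the same math follows; the `if alpha >= 0` guard is assumed from the usual torchvision formulation, and `sigmoid_focal_loss_ref` is a hypothetical name, not part of the commit:

```python
import numpy as np

def sigmoid_focal_loss_ref(pred, tgt, mask=None, alpha=0.25, gamma=2.0):
  p = 1 / (1 + np.exp(-pred))                            # sigmoid of the logits
  ce = -(tgt * np.log(p) + (1 - tgt) * np.log(1 - p))    # elementwise BCE
  if mask is not None: p = p * mask                      # mirror the diff: mask the probabilities too
  p_t = p * tgt + (1 - p) * (1 - tgt)                    # probability of the true class
  loss = ce * (1 - p_t) ** gamma                         # down-weight easy examples
  if alpha >= 0: loss = loss * (alpha * tgt + (1 - alpha) * (1 - tgt))  # assumed guard
  if mask is not None: loss = loss * mask                # zero out ignored anchors
  return loss.sum()                                      # reduction="sum"

print(sigmoid_focal_loss_ref(np.array([2.0, -1.0]), np.array([1.0, 0.0])))
```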
examples/mlperf/model_train.py
@@ -370,7 +370,7 @@ def train_retinanet():
 
   def _data_get(it):
     x, y_boxes, y_labels, matches, cookie = next(it)
-    return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels, matches.shard(GPUS, axis=0), cookie
+    return x.shard(GPUS, axis=0).realize(), y_boxes, y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), cookie
 
   def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
     # TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
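The change above shards `y_labels` across devices on the batch axis, matching how `x` and `matches` are already distributed. A minimal sketch of tinygrad's `Tensor.shard` pattern; the device list is illustrative and needs two real devices to run:

```python
from tinygrad import Tensor

GPUS = ("GPU:0", "GPU:1")          # illustrative device list
x = Tensor.rand(8, 3, 224, 224)    # input batch
y_labels = Tensor.zeros(8, 1000)   # stand-in for per-anchor labels

# split the batch axis so each device holds half of the 8 samples
x, y_labels = x.shard(GPUS, axis=0).realize(), y_labels.shard(GPUS, axis=0)
print(x.device)  # ('GPU:0', 'GPU:1')
```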
@@ -381,10 +381,11 @@ def train_retinanet():
       return warmup_factor * (1 - alpha) + alpha
     return LambdaLR(optim, _lr_lambda)
 
-  def _train_step(model, optim, lr_scheduler, x, matches):
+  @Tensor.train()
+  def _train_step(model, optim, lr_scheduler, x, y, matches):
     optim.zero_grad()
 
-    y_hat = model(normalize(x, GPUS))
+    y_hat = model(normalize(x, GPUS), y=y, matches=matches)
 
   # ** hyperparameters **
   # using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3
@@ -435,7 +436,7 @@ def train_retinanet():
 
     while proc is not None:
       x, y_boxes, y_labels, matches, proc = proc
-      _train_step(model, optim, lr_scheduler, x, matches) # TODO: enable once full model has been integrated
+      _train_step(model, optim, lr_scheduler, x, y_labels, matches) # TODO: enable once full model has been integrated
 
       if len(prev_cookies) == getenv("STORE_COOKIES", 1): prev_cookies = [] # free previous cookies after gpu work has been enqueued
       try:
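`_train_step` is now wrapped in `@Tensor.train()`, which enables tinygrad's training flag (and thus the behavior the optimizer's step asserts on) for the duration of each call. A self-contained toy version of the pattern; the model and optimizer here are stand-ins, not the commit's code:

```python
from tinygrad import Tensor
from tinygrad.nn.optim import SGD

# toy stand-ins for the real model/optimizer
class ToyModel:
  def __init__(self): self.w = Tensor.rand(10, 1, requires_grad=True)
  def __call__(self, x): return x.matmul(self.w)

model = ToyModel()
optim = SGD([model.w], lr=1e-3)

@Tensor.train()  # sets Tensor.training for the duration of each call
def train_step(x, y):
  optim.zero_grad()
  loss = (model(x) - y).square().mean()
  loss.backward()
  optim.step()
  return loss

print(train_step(Tensor.rand(4, 10), Tensor.rand(4, 1)).item())
```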
extra/models/retinanet.py
@@ -1,8 +1,11 @@
+from typing import Optional
+
 import math
 from tinygrad import Tensor, dtypes
 from tinygrad.helpers import flatten, get_child
 import tinygrad.nn as nn
 from examples.mlperf.initializers import Conv2dNormal, Conv2dKaimingUniform
+from examples.mlperf.losses import sigmoid_focal_loss
 from extra.models.helpers import meshgrid, nms
 from extra.models.resnet import ResNet
 import numpy as np
@@ -47,10 +50,11 @@ class RetinaNet:
     self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
     self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
 
-  def __call__(self, x):
-    return self.forward(x)
-  def forward(self, x):
-    return self.head(self.backbone(x))
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    return self.forward(x, y=y, matches=matches)
+
+  def forward(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    return self.head(self.backbone(x), y=y, matches=matches)
 
   def load_from_pretrained(self):
     model_urls = {
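The model now takes optional targets and dispatches on `Tensor.training`: inference returns head outputs, training returns a loss. A runnable toy mirroring that contract; `ToyHead` and its stand-in loss are illustrative only:

```python
from typing import Optional
from tinygrad import Tensor

# Toy head mirroring the new __call__/forward contract: predictions at
# inference time, a scalar loss when Tensor.training is set.
class ToyHead:
  def __call__(self, x:Tensor, y:Optional[Tensor]=None, matches:Optional[Tensor]=None):
    out = x.sigmoid()
    if Tensor.training:
      assert y is not None and matches is not None, "y and matches should be passed in when training"
      return ((out - y).square() * (matches != -2)).sum()  # stand-in loss with the ignore mask
    return out

head = ToyHead()
x, y, matches = Tensor.rand(2, 8), Tensor.rand(2, 8), Tensor.ones(2, 8)
print(head(x).shape)  # inference path -> (2, 8)
with Tensor.train():
  print(head(x, y=y, matches=matches).item())  # training path -> scalar loss
```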
@@ -136,9 +140,26 @@ class ClassificationHead:
     self.num_classes = num_classes
     self.conv = flatten([(Conv2dNormal(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
     self.cls_logits = Conv2dNormal(in_channels, num_anchors * num_classes, kernel_size=3, padding=1, prior_prob=prior_prob)
-  def __call__(self, x):
+
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
     out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
-    return out[0].cat(*out[1:], dim=1).sigmoid()
+    out = out[0].cat(*out[1:], dim=1).sigmoid()
+
+    if Tensor.training:
+      assert y is not None and matches is not None, "y and matches should be passed in when training"
+      return self._compute_loss(out, y, matches)
+    else:
+      return out
+
+  def _compute_loss(self, x:Tensor, y:Tensor, matches:Tensor) -> Tensor:
+    y = ((y + 1) * matches - 1).one_hot(num_classes=x.shape[-1])
+
+    # find indices for which anchors should be ignored
+    valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
+
+    loss = sigmoid_focal_loss(x, y, mask=valid_idxs, reduction="sum")
+    return loss
+
 
 class RegressionHead:
   def __init__(self, in_channels, num_anchors):
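The `((y + 1) * matches - 1)` expression maps background anchors to label -1, which `one_hot` turns into an all-zero row (no positive class), while `valid_idxs` separately masks out anchors marked -2. A tiny check of that trick, assuming `matches` uses 1 for matched anchors and 0 for background:

```python
from tinygrad import Tensor

y = Tensor([[2, 5]])        # per-anchor class labels
matches = Tensor([[1, 0]])  # first anchor matched, second is background
tgt = ((y + 1) * matches - 1).one_hot(num_classes=6)
print(tgt.numpy())
# anchor 0 -> [0, 0, 1, 0, 0, 0] (one-hot of class 2)
# anchor 1 -> [0, 0, 0, 0, 0, 0] (label -1 one-hots to all zeros)
```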
@@ -152,8 +173,9 @@ class RetinaHead:
   def __init__(self, in_channels, num_anchors, num_classes):
     self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
     self.regression_head = RegressionHead(in_channels, num_anchors)
-  def __call__(self, x):
-    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
+
+  def __call__(self, x:Tensor, y:Optional[Tensor] = None, matches:Optional[Tensor] = None):
+    pred_bbox, pred_class = self.regression_head(x), self.classification_head(x, y=y, matches=matches)
     out = pred_bbox.cat(pred_class, dim=-1)
     return out
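For reference, the concatenation keeps boxes and class scores side by side per anchor, so the inference output layout is unchanged. A small shape sketch with made-up sizes:

```python
from tinygrad import Tensor

batch, anchors, num_classes = 2, 120, 264  # made-up sizes
pred_bbox = Tensor.rand(batch, anchors, 4)
pred_class = Tensor.rand(batch, anchors, num_classes)
out = pred_bbox.cat(pred_class, dim=-1)
print(out.shape)  # (2, 120, 268): 4 box deltas then 264 class scores per anchor
```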