RetinaNet model type annotations and loss functions (#9822)

* add type annotations and loss functions for training

* combine sum of multiple dims inside loss functions
This commit is contained in:
Francis Lata
2025-04-10 00:31:37 -04:00
committed by GitHub
parent 06a928b341
commit eb2e59db42

View File

@@ -1,9 +1,14 @@
import math
from tinygrad import Tensor, dtypes
from tinygrad.helpers import flatten, get_child
import tinygrad.nn as nn
from examples.mlperf.helpers import generate_anchors, BoxCoder
from examples.mlperf.losses import sigmoid_focal_loss, l1_loss
from extra.models.resnet import ResNet
import tinygrad.nn as nn
import numpy as np
ConvFPN = ConvHead = ConvClassificationHeadLogits = nn.Conv2d
def nms(boxes, scores, thresh=0.5):
x1, y1, x2, y2 = np.rollaxis(boxes, 1)
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
@@ -30,26 +35,8 @@ def decode_bbox(offsets, anchors):
pred_x2, pred_y2 = pred_cx + 0.5 * pred_w, pred_cy + 0.5 * pred_h
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
def generate_anchors(input_size, grid_sizes, scales, aspect_ratios):
assert len(scales) == len(aspect_ratios) == len(grid_sizes)
anchors = []
for s, ar, gs in zip(scales, aspect_ratios, grid_sizes):
s, ar = np.array(s), np.array(ar)
h_ratios = np.sqrt(ar)
w_ratios = 1 / h_ratios
ws = (w_ratios[:, None] * s[None, :]).reshape(-1)
hs = (h_ratios[:, None] * s[None, :]).reshape(-1)
base_anchors = (np.stack([-ws, -hs, ws, hs], axis=1) / 2).round()
stride_h, stride_w = input_size[0] // gs[0], input_size[1] // gs[1]
shifts_x, shifts_y = np.meshgrid(np.arange(gs[1]) * stride_w, np.arange(gs[0]) * stride_h)
shifts_x = shifts_x.reshape(-1)
shifts_y = shifts_y.reshape(-1)
shifts = np.stack([shifts_x, shifts_y, shifts_x, shifts_y], axis=1, dtype=np.float32)
anchors.append((shifts[:, None] + base_anchors[None, :]).reshape(-1, 4))
return anchors
class RetinaNet:
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
def __init__(self, backbone:ResNet, num_classes:int=264, num_anchors:int=9, scales:list[int]|None=None, aspect_ratios:list[float]|None=None):
assert isinstance(backbone, ResNet)
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
@@ -58,12 +45,12 @@ class RetinaNet:
self.backbone = ResNetFPN(backbone)
self.head = RetinaHead(self.backbone.out_channels, num_anchors=num_anchors, num_classes=num_classes)
self.anchor_gen = lambda input_size: generate_anchors(input_size, self.backbone.compute_grid_sizes(input_size), scales, aspect_ratios)
def __call__(self, x):
return self.forward(x)
def forward(self, x):
return self.head(self.backbone(x))
def __call__(self, x:Tensor, **kwargs):
return self.forward(x, **kwargs)
def forward(self, x:Tensor, **kwargs):
return self.head(self.backbone(x), **kwargs)
def load_from_pretrained(self):
model_urls = {
@@ -82,7 +69,7 @@ class RetinaNet:
# predictions: (BS, (H1W1+...+HmWm)A, 4 + K)
def postprocess_detections(self, predictions, input_size=(800, 800), image_sizes=None, orig_image_sizes=None, score_thresh=0.05, topk_candidates=1000, nms_thresh=0.5):
anchors = self.anchor_gen(input_size)
anchors = generate_anchors(input_size)
grid_sizes = self.backbone.compute_grid_sizes(input_size)
split_idx = np.cumsum([int(self.num_anchors * sz[0] * sz[1]) for sz in grid_sizes[:-1]])
detections = []
@@ -145,33 +132,73 @@ class RetinaNet:
return detections
class ClassificationHead:
def __init__(self, in_channels, num_anchors, num_classes):
def __init__(self, in_channels:int, num_anchors:int, num_classes:int):
self.num_classes = num_classes
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
def __call__(self, x):
self.conv = flatten([(ConvHead(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.cls_logits = ConvClassificationHeadLogits(in_channels, num_anchors * num_classes, kernel_size=3, padding=1)
def __call__(self, x:Tensor, labels:Tensor|None=None, matches:Tensor|None=None):
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
return out[0].cat(*out[1:], dim=1).sigmoid()
out = out[0].cat(*out[1:], dim=1)
if Tensor.training:
assert labels is not None and matches is not None, "labels and matches should be passed in when training"
return self._compute_loss(out.cast(dtypes.float32), labels, matches)
return out.sigmoid()
def _compute_loss(self, x:Tensor, labels:Tensor, matches:Tensor) -> Tensor:
labels = ((labels + 1) * (fg_idxs := matches >= 0) - 1).one_hot(num_classes=x.shape[-1])
valid_idxs = (matches != -2).reshape(matches.shape[0], -1, 1)
loss = valid_idxs.where(sigmoid_focal_loss(x, labels), 0).sum((-1, -2))
loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
return loss
class RegressionHead:
def __init__(self, in_channels, num_anchors):
self.conv = flatten([(nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, padding=1)
def __call__(self, x):
def __init__(self, in_channels:int, num_anchors:int, box_coder:BoxCoder|None=None):
self.conv = flatten([(ConvHead(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.bbox_reg = ConvHead(in_channels, num_anchors * 4, kernel_size=3, padding=1)
if box_coder is None:
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0), apply_to_remove=False)
self.box_coder = box_coder
def __call__(self, x:Tensor, bboxes:Tensor|None=None, matches:Tensor|None=None, anchors:Tensor|None=None):
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
return out[0].cat(*out[1:], dim=1)
out = out[0].cat(*out[1:], dim=1)
if Tensor.training:
assert bboxes is not None and matches is not None and anchors is not None, "bboxes, matches, and anchors should be passed in when training"
return self._compute_loss(out, bboxes, matches, anchors)
return out
def _compute_loss(self, x:Tensor, bboxes:Tensor, matches:Tensor, anchors:Tensor) -> Tensor:
mask = (fg_idxs := matches >= 0).reshape(matches.shape[0], -1, 1)
x = x * mask
tgt = self.box_coder.encode(bboxes, anchors) * mask
loss = l1_loss(x, tgt).sum((-1, -2))
loss = (loss / fg_idxs.sum(-1)).sum() / matches.shape[0]
return loss
class RetinaHead:
def __init__(self, in_channels, num_anchors, num_classes):
def __init__(self, in_channels:int, num_anchors:int, num_classes:int):
self.classification_head = ClassificationHead(in_channels, num_anchors, num_classes)
self.regression_head = RegressionHead(in_channels, num_anchors)
def __call__(self, x):
def __call__(self, x:Tensor, **kwargs) -> Tensor|dict[str, Tensor]:
if Tensor.training:
return {
"classification_loss": self.classification_head(x, labels=kwargs["labels"], matches=kwargs["matches"]),
"regression_loss": self.regression_head(x, bboxes=kwargs["bboxes"], matches=kwargs["matches"], anchors=kwargs["anchors"])
}
pred_bbox, pred_class = self.regression_head(x), self.classification_head(x)
out = pred_bbox.cat(pred_class, dim=-1)
return out
class ResNetFPN:
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
def __init__(self, resnet:ResNet, out_channels:int=256, returned_layers:list[int]=[2, 3, 4]):
self.out_channels = out_channels
self.body = resnet
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
@@ -181,7 +208,7 @@ class ResNetFPN:
def compute_grid_sizes(self, input_size):
return np.ceil(np.array(input_size)[None, :] / 2 ** np.arange(3, 8)[:, None])
def __call__(self, x):
def __call__(self, x:Tensor):
out = self.body.bn1(self.body.conv1(x)).relu()
out = out.pad([1,1,1,1]).max_pool2d((3,3), 2)
out = out.sequential(self.body.layer1)
@@ -191,12 +218,12 @@ class ResNetFPN:
return self.fpn([p3, p4, p5])
class ExtraFPNBlock:
def __init__(self, in_channels, out_channels):
self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
def __init__(self, in_channels:int, out_channels:int):
self.p6 = ConvFPN(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.p7 = ConvFPN(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
self.use_P5 = in_channels == out_channels
def __call__(self, p, c):
def __call__(self, p:Tensor, c:Tensor):
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
@@ -205,14 +232,14 @@ class ExtraFPNBlock:
return p
class FPN:
def __init__(self, in_channels_list, out_channels, extra_blocks=None):
def __init__(self, in_channels_list:list[int], out_channels:int, extra_blocks:ExtraFPNBlock|None=None):
self.inner_blocks, self.layer_blocks = [], []
for in_channels in in_channels_list:
self.inner_blocks.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
self.layer_blocks.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
self.inner_blocks.append(ConvFPN(in_channels, out_channels, kernel_size=1))
self.layer_blocks.append(ConvFPN(out_channels, out_channels, kernel_size=3, padding=1))
self.extra_blocks = ExtraFPNBlock(256, 256) if extra_blocks is None else extra_blocks
def __call__(self, x):
def __call__(self, x:Tensor):
last_inner = self.inner_blocks[-1](x[-1])
results = [self.layer_blocks[-1](last_inner)]
for idx in range(len(x) - 2, -1, -1):