import math
from typing import Union

from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix

# rejection sampling truncated randn
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
  CNT=8
  x = Tensor.randn(*(*shape, CNT), dtype=dtype, **kwargs)
  ctr = Tensor.arange(CNT).reshape((1,) * len(x.shape[:-1]) + (CNT,)).expand(x.shape)
  take = (x.abs() <= truncstds).where(ctr, CNT).min(axis=-1, keepdim=True)  # set to 0 if no good samples
  return (ctr == take).where(x, 0).sum(axis=-1)

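# usage sketch (illustrative helper, not used elsewhere): every draw either comes from an
# accepted sample (|x| <= truncstds) or falls back to 0, so results stay within truncstds stds.
def _example_rand_truncn():
  w = rand_truncn(64, 128, truncstds=2, dtype=dtypes.float32)
  assert w.shape == (64, 128) and (w.abs() <= 2).all().item()
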
# https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065
def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
  return std * rand_truncn(*shape, **kwargs)

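# usage sketch (illustrative helper): fan_in is prod(shape[1:]); dividing by 0.8796... undoes the
# variance lost to truncation, so the drawn weights end up with std close to sqrt(2 / fan_in).
def _example_he_normal():
  w = he_normal(256, 64, 3, 3)  # conv-style (out_channels, in_channels, kH, kW) shape
  print(w.std().item())         # expected near math.sqrt(2.0 / (64 * 3 * 3)) ~= 0.059
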
class Conv2dHeNormal(nn.Conv2d):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.in_channels, self.out_channels = in_channels, out_channels  # for testing
    self.weight = he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0, dtype=dtypes.float32)
    if bias: self.bias = self.bias.cast(dtypes.float32)
  def __call__(self, x: Tensor):
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)

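# usage sketch (illustrative helper): parameters live in float32 while the forward pass casts to
# dtypes.default_float, so switching default_float to float16 gives mixed-precision compute
# without touching the stored master weights.
def _example_conv2d_he_normal():
  conv = Conv2dHeNormal(16, 32, 3, padding=1)
  out = conv(Tensor.randn(2, 16, 8, 8))
  assert conv.weight.dtype == dtypes.float32 and out.shape == (2, 32, 8, 8)
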
class Linear(nn.Linear):
  def __init__(self, in_features, out_features, bias=True):
    super().__init__(in_features, out_features, bias=bias)
    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
    if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
  def __call__(self, x:Tensor):
    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)

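# usage sketch (illustrative helper): classification-head style init, weights ~ N(0, 0.01) and a
# zero bias, again stored in float32 and cast at call time.
def _example_linear():
  head = Linear(512, 10)
  assert head(Tensor.randn(4, 512)).shape == (4, 10)
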
class LinearBert(nn.Linear):
  def __init__(self, in_features, out_features, bias=True, std=0.02):
    self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
    self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None

  def __call__(self, x:Tensor):
    return x.cast(dtypes.default_float).linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)

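# usage sketch (illustrative helper): BERT-style truncated-normal init (std=0.02); the input is
# also cast to dtypes.default_float, so activations follow the configured training dtype.
def _example_linear_bert():
  proj = LinearBert(768, 3072)
  assert proj(Tensor.randn(2, 8, 768)).shape == (2, 8, 3072)
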
class EmbeddingBert(nn.Embedding):
  def __init__(self, vocab_size:int, embed_size:int, std=0.02):
    self.vocab_sz, self.embed_sz = vocab_size, embed_size
    self.weight = std * rand_truncn(vocab_size, embed_size, dtype=dtypes.float32)

  def __call__(self, idx:Tensor) -> Tensor:
    if idx.numel() == 0: return Tensor.empty(idx.shape+(self.embed_sz,), dtype=self.weight.dtype, device=self.weight.device)
    arange_shp, weight_shp, big_shp = (1, 1, self.vocab_sz, 1), (1, 1, self.vocab_sz, self.embed_sz), idx.shape+(self.vocab_sz, self.embed_sz,)
    if not hasattr(self, 'arange'): self.arange = Tensor.arange(self.vocab_sz, requires_grad=False, device=self.weight.device).reshape(arange_shp)
    arange, idx, vals = self.arange.expand(big_shp), idx.reshape(idx.shape+(1, 1,)).expand(big_shp), self.weight.cast(dtypes.default_float).reshape(weight_shp).expand(big_shp)
    # TODO: contiguous() here because the embedding dropout creates different asts on each device, and search becomes very slow.
    # Should fix with fixing random ast on multi device, and fuse arange to make embedding fast.
    return (arange == idx).mul(vals).sum(2, dtype=vals.dtype).contiguous()

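# usage sketch (illustrative helper): the lookup is a one-hot style reduction, token ids are
# compared against an arange over the vocab and the matching rows of the weight are summed out.
def _example_embedding_bert():
  emb = EmbeddingBert(vocab_size=1000, embed_size=64)
  ids = Tensor([[1, 2, 3], [4, 5, 6]])
  assert emb(ids).shape == (2, 3, 64)
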
class LayerNormBert:
  def __init__(self, normalized_shape:Union[int, tuple[int, ...]], eps:float=1e-12, elementwise_affine:bool=True):
    self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
    self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
    self.weight, self.bias = (Tensor.ones(*self.normalized_shape, dtype=dtypes.float32), Tensor.zeros(*self.normalized_shape, dtype=dtypes.float32)) if elementwise_affine else (None, None)

  def __call__(self, x:Tensor):
    assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
    xn = x.cast(dtypes.float32).layernorm(eps=self.eps, axis=self.axis).cast(x.dtype)
    if not self.elementwise_affine: return xn
    return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))

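# usage sketch (illustrative helper): the normalization itself always runs in float32 (the BERT
# eps of 1e-12 is not representable in float16), then the result is cast back to the input dtype.
def _example_layernorm_bert():
  ln = LayerNormBert(64)
  assert ln(Tensor.randn(2, 8, 64)).shape == (2, 8, 64)
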
class FrozenBatchNorm2dRetinaNet(nn.BatchNorm2d):
  def __init__(self, sz:int, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
    self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum

    self.weight = Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False) if affine else None
    self.bias = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False) if affine else None

    if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, dtype=dtypes.float32, requires_grad=False), Tensor.ones(sz, dtype=dtypes.float32, requires_grad=False)
    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.long, requires_grad=False)

  def __call__(self, x:Tensor) -> Tensor:
    batch_mean, batch_var = super().calc_stats(x.cast(dtypes.float32))
    if self.track_running_stats and Tensor.training:
      self.running_mean.assign((1-self.momentum) * self.running_mean + self.momentum * batch_mean.detach().cast(self.running_mean.dtype))
      self.running_var.assign((1-self.momentum) * self.running_var + self.momentum * x.numel()/(x.numel()-x.shape[1]) * batch_var.detach().cast(self.running_var.dtype))
      self.num_batches_tracked += 1
    return x.cast(dtypes.float32).batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt()).cast(x.dtype)

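# usage sketch (illustrative helper): "frozen" means the affine parameters are created with
# requires_grad=False and so stay out of gradient updates; running statistics are still tracked
# whenever Tensor.training is set.
def _example_frozen_batchnorm():
  bn = FrozenBatchNorm2dRetinaNet(8)
  Tensor.training = True
  bn(Tensor.randn(2, 8, 4, 4)).realize()
  Tensor.training = False
  assert bn.num_batches_tracked.item() == 1
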
class Conv2dNormalRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True, prior_prob:float|None=None):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.weight = Tensor.normal(*self.weight.shape, std=0.01, dtype=dtypes.float32)
    if bias:
      if prior_prob:
        prior_prob = Tensor(prior_prob, device=self.bias.device, dtype=dtypes.float32).expand(*self.bias.shape)
        self.bias = -(((1 - prior_prob) / prior_prob).log())
      else: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, padding=self.padding)

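# usage sketch (illustrative helper): with prior_prob=p the bias starts at -log((1 - p) / p), so
# the initial classification logits satisfy sigmoid(bias) == p, the RetinaNet focal-loss prior.
def _example_conv2d_normal_retinanet():
  cls_logits = Conv2dNormalRetinaNet(256, 9, 3, padding=1, prior_prob=0.01)
  assert abs(cls_logits.bias.sigmoid().mean().item() - 0.01) < 1e-5
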
class Conv2dKaimingUniformRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1, dtype=dtypes.float32)
    if bias: self.bias = Tensor.zeros_like(self.bias, dtype=dtypes.float32)

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, padding=self.padding)

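# usage sketch (illustrative helper): kaiming_uniform with a=1 draws weights from roughly
# +/- sqrt(3 / fan_in) (the init torchvision uses for its FPN convs), with a zero bias.
def _example_conv2d_kaiming_uniform_retinanet():
  fpn_conv = Conv2dKaimingUniformRetinaNet(256, 256, 3, padding=1)
  assert fpn_conv.weight.abs().max().item() < math.sqrt(3 / (256 * 3 * 3)) * 1.01
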
class Conv2dRetinaNet(nn.Conv2d):
  def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
               stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
               bias:bool=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale, dtype=dtypes.float32)
    self.bias: Tensor|None = Tensor.uniform(out_channels, low=-scale, high=scale, dtype=dtypes.float32) if bias else None

  def __call__(self, x:Tensor) -> Tensor:
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    groups=self.groups, stride=self.stride, dilation=self.dilation, padding=self.padding)
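# usage sketch (illustrative helper): weights and bias are uniform within +/- 1/sqrt(fan_in), the
# familiar PyTorch Conv2d default bound, but kept as float32 master weights and cast per call.
def _example_conv2d_retinanet():
  stem = Conv2dRetinaNet(3, 64, 7, stride=2, padding=3)
  assert stem(Tensor.randn(1, 3, 32, 32)).shape == (1, 64, 16, 16)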