# tinygrad/examples/mlperf/initializers.py

import math
from tinygrad import Tensor, nn, dtypes
from tinygrad.helpers import prod, argfix
# rejection sampling truncated randn
def rand_truncn(*shape, dtype=None, truncstds=2, **kwargs) -> Tensor:
  CNT = 8
  x = Tensor.randn(*(*shape, CNT), dtype=dtype, **kwargs)
  ctr = Tensor.arange(CNT).reshape((1,) * len(x.shape[:-1]) + (CNT,)).expand(x.shape)
  take = (x.abs() <= truncstds).where(ctr, CNT).min(axis=-1, keepdim=True)  # set to 0 if no good samples
  return (ctr == take).where(x, 0).sum(axis=-1)
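
# Illustrative sketch (not part of the original file): rand_truncn draws CNT=8
# candidates per output element, keeps the first one that lands within
# truncstds standard deviations, and falls back to 0 if all 8 are rejected.
# A quick sanity check of that behavior:
def _demo_rand_truncn():
  samples = rand_truncn(4, 4, truncstds=2)
  assert (abs(samples.numpy()) <= 2).all()  # every kept sample (or the 0 fallback) is within 2 stds
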
# https://github.com/keras-team/keras/blob/v2.15.0/keras/initializers/initializers.py#L1026-L1065
def he_normal(*shape, a: float = 0.00, **kwargs) -> Tensor:
  std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) / 0.87962566103423978
  return std * rand_truncn(*shape, **kwargs)
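
# Hypothetical sanity check (not part of the original file): 0.87962566103423978
# is the std of a unit normal truncated to [-2, 2] (Keras uses the same
# constant), so dividing by it restores the intended He std of sqrt(2 / fan_in).
def _demo_he_normal():
  w = he_normal(64, 32, 3, 3)  # fan_in = 32*3*3 = 288
  print(w.numpy().std(), "should be close to", math.sqrt(2.0 / 288))
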
class Conv2dHeNormal(nn.Conv2d):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
    super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
    self.in_channels, self.out_channels = in_channels, out_channels  # for testing
    self.weight = he_normal(out_channels, in_channels//groups, *self.kernel_size, a=0.0, dtype=dtypes.float32)
    if bias: self.bias = self.bias.cast(dtypes.float32)
  def __call__(self, x: Tensor):
    return x.conv2d(self.weight.cast(dtypes.default_float), self.bias.cast(dtypes.default_float) if self.bias is not None else None,
                    padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
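
# Usage sketch (not part of the original file): the weights are kept in float32
# "master" storage and cast to dtypes.default_float on every forward pass, so
# the same layer works under mixed-precision training (e.g. with default_float
# set to half).
def _demo_conv():
  conv = Conv2dHeNormal(16, 32, kernel_size=3, padding=1, bias=False)
  out = conv(Tensor.randn(1, 16, 8, 8))
  assert conv.weight.dtype == dtypes.float32 and out.shape == (1, 32, 8, 8)
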
class Linear(nn.Linear):
  def __init__(self, in_features, out_features, bias=True):
    super().__init__(in_features, out_features, bias=bias)
    self.weight = Tensor.normal((out_features, in_features), mean=0.0, std=0.01, dtype=dtypes.float32)
    if bias: self.bias = Tensor.zeros(out_features, dtype=dtypes.float32)
  def __call__(self, x: Tensor):
    return x.linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
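
# Usage sketch (not part of the original file): the classifier head follows the
# same float32 master-weight pattern, initialized from N(0, 0.01) — a common
# choice for a ResNet's final fully-connected layer.
def _demo_linear():
  fc = Linear(512, 1000)
  logits = fc(Tensor.randn(2, 512))
  assert logits.shape == (2, 1000) and fc.weight.dtype == dtypes.float32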