make BatchNorm work for 2D and 3D (#5477)

* make BatchNorm work for 2D and 3D

* beautiful mnist shouldn't use BatchNorm2d
This commit is contained in:
George Hotz
2024-07-14 11:39:58 -07:00
committed by GitHub
parent e41ab66653
commit a9f5a764dc
3 changed files with 25 additions and 15 deletions

View File

@@ -9,10 +9,10 @@ class Model:
self.layers: List[Callable[[Tensor], Tensor]] = [
nn.Conv2d(1, 32, 5), Tensor.relu,
nn.Conv2d(32, 32, 5), Tensor.relu,
nn.BatchNorm2d(32), Tensor.max_pool2d,
nn.BatchNorm(32), Tensor.max_pool2d,
nn.Conv2d(32, 64, 3), Tensor.relu,
nn.Conv2d(64, 64, 3), Tensor.relu,
nn.BatchNorm2d(64), Tensor.max_pool2d,
nn.BatchNorm(64), Tensor.max_pool2d,
lambda x: x.flatten(1), nn.Linear(576, 10)]
def __call__(self, x:Tensor) -> Tensor:
  """Forward pass: hand `self.layers` to `x.sequential` and return its result."""
  layers = self.layers
  return x.sequential(layers)

View File

@@ -6,7 +6,7 @@ from tinygrad import Tensor, Device, TinyJit
from tinygrad.helpers import CI, Context
from tinygrad.ops import MetaOps
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -27,12 +27,12 @@ class TestNN(unittest.TestCase):
torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing, ignore_index=ignore_index)(torch_input, torch_target)
np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
def test_batchnorm2d(self, training=False):
def test_batchnorm2d(self, training=False, threed=False):
with Tensor.train(training):
szs = [4, 8, 16, 32]
for sz in szs:
# create in tinygrad
bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
bn = BatchNorm(sz, eps=1e-5, track_running_stats=training)
bn.weight = Tensor.randn(sz)
bn.bias = Tensor.randn(sz)
bn.running_mean = Tensor.randn(sz)
@@ -41,7 +41,10 @@ class TestNN(unittest.TestCase):
# create in torch
with torch.no_grad():
tbn = torch.nn.BatchNorm2d(sz).eval()
if threed:
tbn = torch.nn.BatchNorm3d(sz).eval()
else:
tbn = torch.nn.BatchNorm2d(sz).eval()
tbn.training = training
tbn.weight[:] = torch.tensor(bn.weight.numpy())
tbn.bias[:] = torch.tensor(bn.bias.numpy())
@@ -52,7 +55,10 @@ class TestNN(unittest.TestCase):
np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
# trial
inn = Tensor.randn(2, sz, 3, 3)
if threed:
inn = Tensor.randn(2, sz, 3, 3, 3)
else:
inn = Tensor.randn(2, sz, 3, 3)
# in tinygrad
outt = bn(inn)
@@ -68,6 +74,9 @@ class TestNN(unittest.TestCase):
def test_batchnorm2d_training(self):
  """Re-run `test_batchnorm2d` with its `training` flag set to True."""
  self.test_batchnorm2d(training=True)
def test_batchnorm3d(self):
  """Run the shared batchnorm comparison with training disabled and the `threed` variant selected."""
  self.test_batchnorm2d(training=False, threed=True)
def test_batchnorm3d_training(self):
  """Run the shared batchnorm comparison with training enabled and the `threed` variant selected."""
  self.test_batchnorm2d(training=True, threed=True)
def test_batchnorm_axis(self):
sz = (2, 4, 3, 2, 2)
x = Tensor.randn(sz)

View File

@@ -4,9 +4,9 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
from tinygrad.nn import optim, state, datasets # noqa: F401
class BatchNorm2d:
class BatchNorm:
"""
Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
Applies Batch Normalization over a 2D or 3D input.
- Described: https://paperswithcode.com/method/batch-normalization
- Paper: https://arxiv.org/abs/1502.03167v3
@@ -20,7 +20,7 @@ class BatchNorm2d:
```
```python exec="true" source="above" session="tensor" result="python"
norm = nn.BatchNorm2d(3)
norm = nn.BatchNorm(3)
t = Tensor.rand(2, 3, 4, 4)
print(t.mean().item(), t.std().item())
```
@@ -39,13 +39,14 @@ class BatchNorm2d:
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
def __call__(self, x:Tensor):
shape_mask = [1, -1, *([1]*(x.ndim-2))]
if Tensor.training:
# This requires two full memory accesses to x
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
batch_mean = x.mean(axis=(0,2,3))
y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
batch_var = (y*y).mean(axis=(0,2,3))
batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
batch_var = (y*y).mean(axis=reduce_axes)
batch_invstd = batch_var.add(self.eps).pow(-0.5)
# NOTE: wow, this is done all throughout training in most PyTorch models
@@ -56,9 +57,9 @@ class BatchNorm2d:
else:
batch_mean = self.running_mean
# NOTE: this can be precomputed for static inference. we expand it here so it fuses
batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
batch_invstd = self.running_var.reshape(shape=shape_mask).expand(x.shape).add(self.eps).rsqrt()
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
# Dimension-specific names are aliases of the same class: BatchNorm builds its
# reshape mask and reduction axes from x.ndim rather than from a fixed rank.
BatchNorm2d = BatchNorm
BatchNorm3d = BatchNorm
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
"""