mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make BatchNorm work for 2D and 3D (#5477)
* make BatchNorm work for 2D and 3D
* beautiful mnist shouldn't use BatchNorm2d
This commit is contained in:
@@ -9,10 +9,10 @@ class Model:
|
||||
self.layers: List[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm2d(32), Tensor.max_pool2d,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm2d(64), Tensor.max_pool2d,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad import Tensor, Device, TinyJit
|
||||
from tinygrad.helpers import CI, Context
|
||||
from tinygrad.ops import MetaOps
|
||||
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
||||
from tinygrad.nn import BatchNorm2d, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
|
||||
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
@@ -27,12 +27,12 @@ class TestNN(unittest.TestCase):
|
||||
torch_loss = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=smoothing, ignore_index=ignore_index)(torch_input, torch_target)
|
||||
np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_batchnorm2d(self, training=False):
|
||||
def test_batchnorm2d(self, training=False, threed=False):
|
||||
with Tensor.train(training):
|
||||
szs = [4, 8, 16, 32]
|
||||
for sz in szs:
|
||||
# create in tinygrad
|
||||
bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
|
||||
bn = BatchNorm(sz, eps=1e-5, track_running_stats=training)
|
||||
bn.weight = Tensor.randn(sz)
|
||||
bn.bias = Tensor.randn(sz)
|
||||
bn.running_mean = Tensor.randn(sz)
|
||||
@@ -41,7 +41,10 @@ class TestNN(unittest.TestCase):
|
||||
|
||||
# create in torch
|
||||
with torch.no_grad():
|
||||
tbn = torch.nn.BatchNorm2d(sz).eval()
|
||||
if threed:
|
||||
tbn = torch.nn.BatchNorm3d(sz).eval()
|
||||
else:
|
||||
tbn = torch.nn.BatchNorm2d(sz).eval()
|
||||
tbn.training = training
|
||||
tbn.weight[:] = torch.tensor(bn.weight.numpy())
|
||||
tbn.bias[:] = torch.tensor(bn.bias.numpy())
|
||||
@@ -52,7 +55,10 @@ class TestNN(unittest.TestCase):
|
||||
np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
|
||||
|
||||
# trial
|
||||
inn = Tensor.randn(2, sz, 3, 3)
|
||||
if threed:
|
||||
inn = Tensor.randn(2, sz, 3, 3, 3)
|
||||
else:
|
||||
inn = Tensor.randn(2, sz, 3, 3)
|
||||
|
||||
# in tinygrad
|
||||
outt = bn(inn)
|
||||
@@ -68,6 +74,9 @@ class TestNN(unittest.TestCase):
|
||||
def test_batchnorm2d_training(self):
|
||||
self.test_batchnorm2d(True)
|
||||
|
||||
def test_batchnorm3d(self): self.test_batchnorm2d(False, True)
|
||||
def test_batchnorm3d_training(self): self.test_batchnorm2d(True, True)
|
||||
|
||||
def test_batchnorm_axis(self):
|
||||
sz = (2, 4, 3, 2, 2)
|
||||
x = Tensor.randn(sz)
|
||||
|
||||
@@ -4,9 +4,9 @@ from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.nn import optim, state, datasets # noqa: F401
|
||||
|
||||
class BatchNorm2d:
|
||||
class BatchNorm:
|
||||
"""
|
||||
Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with additional channel dimension).
|
||||
Applies Batch Normalization over a 2D or 3D input.
|
||||
|
||||
- Described: https://paperswithcode.com/method/batch-normalization
|
||||
- Paper: https://arxiv.org/abs/1502.03167v3
|
||||
@@ -20,7 +20,7 @@ class BatchNorm2d:
|
||||
```
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
norm = nn.BatchNorm2d(3)
|
||||
norm = nn.BatchNorm(3)
|
||||
t = Tensor.rand(2, 3, 4, 4)
|
||||
print(t.mean().item(), t.std().item())
|
||||
```
|
||||
@@ -39,13 +39,14 @@ class BatchNorm2d:
|
||||
self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
shape_mask = [1, -1, *([1]*(x.ndim-2))]
|
||||
if Tensor.training:
|
||||
# This requires two full memory accesses to x
|
||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||
batch_mean = x.mean(axis=(0,2,3))
|
||||
y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
batch_mean = x.mean(axis=(reduce_axes:=tuple(x for x in range(x.ndim) if x != 1)))
|
||||
y = (x - batch_mean.detach().reshape(shape=shape_mask)) # d(var)/d(mean) = 0
|
||||
batch_var = (y*y).mean(axis=reduce_axes)
|
||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||
|
||||
# NOTE: wow, this is done all throughout training in most PyTorch models
|
||||
@@ -56,9 +57,9 @@ class BatchNorm2d:
|
||||
else:
|
||||
batch_mean = self.running_mean
|
||||
# NOTE: this can be precomputed for static inference. we expand it here so it fuses
|
||||
batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
|
||||
|
||||
batch_invstd = self.running_var.reshape(shape=shape_mask).expand(x.shape).add(self.eps).rsqrt()
|
||||
return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
|
||||
BatchNorm2d = BatchNorm3d = BatchNorm
|
||||
|
||||
def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user