UnsyncedBatchNorm with synced trainable weights for hlb cifar (#3472)

* UnsyncedBatchNorm with synced trainable weights for hlb cifar

* multitensor reshape tests

* test mlb assign change axis

* E501

* argfix axis

* don't import batchnorm from hlb_cifar in test_multitensor

* pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB

* add backprop test for UnsyncedBatchNorm

* break out MLB assign and reshape changes

* manually shard running mean and running var

* don't shard unless syncbn=0

* replace nn.BatchNorm2d with UnsyncedBatchNorm

* don't increment num_batches_tracked if not tracking running stats

* update tests

* oops

* Revert "oops"

This reverts commit 5e8a67a535.

* Revert "update tests"

This reverts commit 7ebf65d89a.

* Revert "don't increment num_batches_tracked if not tracking running stats"

This reverts commit 78de0ea9ee.

* Revert "replace nn.BatchNorm2d with UnsyncedBatchNorm"

This reverts commit d03da53da7.

* don't increment num_batches_tracked if not tracking running stats

* oops

* test_batchnorm_axis

* compare against torch

* types

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: David Hou (committed by GitHub)
Date: 2024-02-29 19:52:07 -08:00
Commit: e5385eecfc (parent 5a6e151844)
4 changed files with 102 additions and 38 deletions
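
For context on what the new tests exercise: the idea behind UnsyncedBatchNorm is that each device keeps its own running mean and variance, while the trainable weight and bias are shared so their gradients stay identical across devices. A minimal NumPy sketch of that idea follows; it is not the examples/hlb_cifar10.py implementation, and the class and argument names are illustrative only.

import numpy as np

class ToyUnsyncedBatchNorm:
  # per-device running stats, one shared set of trainable parameters
  def __init__(self, num_features, num_devices, eps=1e-5, momentum=0.1):
    self.eps, self.momentum = eps, momentum
    # trainable parameters: a single copy, kept in sync across devices
    self.weight = np.ones(num_features, dtype=np.float32)
    self.bias = np.zeros(num_features, dtype=np.float32)
    # running statistics: one row per device, never synchronized
    self.running_mean = np.zeros((num_devices, num_features), dtype=np.float32)
    self.running_var = np.ones((num_devices, num_features), dtype=np.float32)

  def __call__(self, x, device_idx):
    # x is the (N, C, H, W) shard that lives on device `device_idx`
    mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
    self.running_mean[device_idx] = (1 - self.momentum) * self.running_mean[device_idx] + self.momentum * mean
    self.running_var[device_idx] = (1 - self.momentum) * self.running_var[device_idx] + self.momentum * var
    xhat = (x - mean[None, :, None, None]) / np.sqrt(var[None, :, None, None] + self.eps)
    return self.weight[None, :, None, None] * xhat + self.bias[None, :, None, None]

In the tinygrad module, the per-device statistics are what "manually shard running mean and running var" refers to: those buffers are sharded across the GPUs while weight and bias are replicated, which appears to be why the backprop test below only calls p.shard_ on parameters that are not already MultiLazyBuffers.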


@@ -541,6 +541,30 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       out.mean().backward()
       optim.step()
 
+  def test_unsynced_backprop_sync_weights(self):
+    from extra.lr_scheduler import OneCycleLR
+    from examples.hlb_cifar10 import UnsyncedBatchNorm
+    from tinygrad.features.multi import MultiLazyBuffer
+    GPUS = (d1, d2)
+    with Tensor.train():
+      conv = nn.Conv2d(3, 16, 3)
+      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
+      for p in get_parameters([conv, bn]):
+        if not isinstance(p.lazydata, MultiLazyBuffer):
+          p.shard_(GPUS)
+      optim = nn.optim.Adam(get_parameters([conv, bn]))
+      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
+      lr_sched.step()
+      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
+      out = bn(conv(fake_image))
+      optim.zero_grad()
+      out.mean().backward()
+      optim.step()
+
   @given(strat.sampled_from((False, True)))
   def test_batchnorm(self, is_training):
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
@@ -564,13 +588,14 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     bn_ts[0].cat(*bn_ts[1:]).numpy()
 
   def test_synced_vs_unsynced_bn(self):
-    from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm
+    from examples.hlb_cifar10 import UnsyncedBatchNorm
+    from tinygrad.nn import BatchNorm2d
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
     x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
     with Tensor.train():
-      synced_bn = BatchNorm(8)
-      unsynced_bn = UnsyncedBatchNorm(8)
+      synced_bn = BatchNorm2d(8)
+      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
       for p in get_parameters([synced_bn, unsynced_bn]):
         p.shard_(devices)
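
The distinction test_synced_vs_unsynced_bn probes is easiest to see with a toy example: when a batch is sharded over axis 0, a plain BatchNorm2d normalizes every sample with statistics of the whole batch, whereas UnsyncedBatchNorm normalizes each shard with that shard's own statistics. A rough NumPy illustration (the numbers are made up, not taken from the test):

import numpy as np

x = np.arange(8, dtype=np.float32)   # one channel of a batch of 8
shards = np.split(x, 2)              # the same batch sharded across 2 devices

print(x.mean())                      # synced statistics: 3.5, shared by both devices
print([s.mean() for s in shards])    # unsynced statistics: [1.5, 5.5], one per device

The test above feeds an all-ones tensor, for which per-shard and whole-batch statistics coincide, so the two modules should produce matching outputs while still going through the sharded code path.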


@@ -62,6 +62,24 @@ class TestNN(unittest.TestCase):
   def test_batchnorm2d_training(self):
     self.test_batchnorm2d(True)
 
+  def test_batchnorm_axis(self):
+    sz = (2, 4, 3, 2, 2)
+    x = Tensor.randn(sz)
+    weight = Tensor.randn(2, 3)
+    bias = Tensor.randn(2, 3)
+    mean = Tensor.randn(2, 3)
+    invstd = Tensor.randn(2, 3)
+    a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
+         .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
+    b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
+         .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
+    t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
+    t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
+    t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
+    torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias)
+    np.testing.assert_allclose(a.numpy(), b.numpy())
+
   def test_linear(self):
     def _test_linear(x, in_dim, out_dim):
       # create in tinygrad
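
The new test_batchnorm_axis asserts that batchnorm over axis=(0, 2) of a (2, 4, 3, 2, 2) tensor, with (2, 3)-shaped parameters, is equivalent to moving those two axes next to each other, flattening them into a single 6-wide channel dimension, and running an ordinary channel-wise batchnorm. A rough NumPy restatement of that equivalence (my reading of the test; the (x - mean) * weight * invstd + bias convention is an assumption, not taken from the diff):

import numpy as np

x = np.random.randn(2, 4, 3, 2, 2).astype(np.float32)
weight, bias = np.random.randn(2, 3), np.random.randn(2, 3)
mean, invstd = np.random.randn(2, 3), np.random.randn(2, 3)

def bn(x, w, b, m, s, shape):
  # broadcast the parameters against x using the given shape and normalize
  return (x - m.reshape(shape)) * w.reshape(shape) * s.reshape(shape) + b.reshape(shape)

# "axis" form: (2, 3) parameters broadcast against axes 0 and 2 of x
a = bn(x, weight, bias, mean, invstd, (2, 1, 3, 1, 1))
a = a.transpose(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)

# flattened form: merge axes 0 and 2 into one 6-wide channel dim, then plain batchnorm
xf = x.transpose(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
b = bn(xf, weight.ravel(), bias.ravel(), mean.ravel(), invstd.ravel(), (1, 6, 1, 1))

np.testing.assert_allclose(a, b)   # both forms apply the same per-element affine transform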