Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
UnsyncedBatchNorm with synced trainable weights for hlb cifar (#3472)
* UnsyncedBatchNorm with synced trainable weights for hlb cifar
* multitensor reshape tests
* test mlb assign change axis
* E501
* argfix axis
* don't import batchnorm from hlb_cifar in test_multitensor
* pass num_devices to UnsyncedBatchNorm in test, allow UnsyncedBatchNorm to be used with LB
* add backprop test for UnsyncedBatchNorm
* break out MLB assign and reshape changes
* manually shard running mean and running var
* don't shard unless syncbn=0
* replace nn.BatchNorm2d with UnsyncedBatchNorm
* don't increment num_batches_tracked if not tracking running stats
* update tests
* oops
* Revert "oops". This reverts commit 5e8a67a535.
* Revert "update tests". This reverts commit 7ebf65d89a.
* Revert "don't increment num_batches_tracked if not tracking running stats". This reverts commit 78de0ea9ee.
* Revert "replace nn.BatchNorm2d with UnsyncedBatchNorm". This reverts commit d03da53da7.
* don't increment num_batched_tracked if not tracking running stats
* oops
* test_batchnorm_axis
* compare against torch
* types

Co-authored-by: chenyu <chenyu@fastmail.com>
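The idea described by the commit message is that each device normalizes with its own batch statistics (and keeps its own, manually sharded, running mean and variance), while the trainable weight and bias stay shared across devices. Below is a minimal single-device sketch of that behaviour, not the hlb_cifar10 implementation: the batch is reshaped into per-device chunks, statistics are reduced per chunk, and one weight/bias pair is broadcast to every chunk.

# Sketch only: per-device batch statistics, shared trainable weight/bias.
from tinygrad import Tensor

NUM_DEVICES, N, C, H, W = 2, 8, 16, 32, 32

def unsynced_bn(x: Tensor, weight: Tensor, bias: Tensor, eps=1e-5) -> Tensor:
  # split the batch into one chunk per device: (D, N//D, C, H, W)
  xr = x.reshape(NUM_DEVICES, -1, C, H, W)
  # statistics are reduced per device chunk (axes 1, 3, 4), not over the whole batch
  mean = xr.mean(axis=(1, 3, 4), keepdim=True)
  var = ((xr - mean) ** 2).mean(axis=(1, 3, 4), keepdim=True)
  xhat = (xr - mean) / (var + eps).sqrt()
  # the trainable weight/bias are the same for every device chunk
  out = xhat * weight.reshape(1, 1, C, 1, 1) + bias.reshape(1, 1, C, 1, 1)
  return out.reshape(N, C, H, W)

x = Tensor.randn(N, C, H, W)
w, b = Tensor.ones(C), Tensor.zeros(C)
print(unsynced_bn(x, w, b).shape)  # (8, 16, 32, 32)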
@@ -541,6 +541,30 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       out.mean().backward()
       optim.step()
 
+  def test_unsynced_backprop_sync_weights(self):
+    from extra.lr_scheduler import OneCycleLR
+    from examples.hlb_cifar10 import UnsyncedBatchNorm
+    from tinygrad.features.multi import MultiLazyBuffer
+    GPUS = (d1, d2)
+
+    with Tensor.train():
+      conv = nn.Conv2d(3, 16, 3)
+      bn = UnsyncedBatchNorm(16, num_devices=len(GPUS))
+
+      for p in get_parameters([conv, bn]):
+        if not isinstance(p.lazydata, MultiLazyBuffer):
+          p.shard_(GPUS)
+      optim = nn.optim.Adam(get_parameters([conv, bn]))
+      lr_sched = OneCycleLR(optim, max_lr=0.1, pct_start=0.1, div_factor=100, final_div_factor=0.1, total_steps=10)
+      lr_sched.step()
+
+      fake_image = Tensor.rand((8, 3, 32, 32)).shard(GPUS, axis=0)
+
+      out = bn(conv(fake_image))
+      optim.zero_grad()
+      out.mean().backward()
+      optim.step()
+
   @given(strat.sampled_from((False, True)))
   def test_batchnorm(self, is_training):
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
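For context on the sharding calls in the new test: shard/shard_ with axis=0 splits a tensor across devices, while omitting the axis replicates it, which is how the input batch gets split but the conv/bn parameters stay in sync. A small sketch of that distinction follows; the device names are an assumption, mirroring the Device.DEFAULT:i pattern used elsewhere in these tests.

# Sketch: axis-sharded input vs. replicated parameters.
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
x = Tensor.rand(8, 3, 32, 32).shard(GPUS, axis=0)  # batch dimension split: 4 images per device
w = Tensor.rand(16, 3, 3, 3).shard(GPUS)           # no axis: replicated, one full copy per device
print(x.shape, w.shape)  # logical shapes stay (8, 3, 32, 32) and (16, 3, 3, 3)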
@@ -564,13 +588,14 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       bn_ts[0].cat(*bn_ts[1:]).numpy()
 
   def test_synced_vs_unsynced_bn(self):
-    from examples.hlb_cifar10 import BatchNorm, UnsyncedBatchNorm
+    from examples.hlb_cifar10 import UnsyncedBatchNorm
+    from tinygrad.nn import BatchNorm2d
     devices = [f"{Device.DEFAULT}:{i}" for i in range(4)]
     x = Tensor.ones(8, 8, 8, 8).contiguous().realize().shard(devices, axis=0)
 
     with Tensor.train():
-      synced_bn = BatchNorm(8)
-      unsynced_bn = UnsyncedBatchNorm(8)
+      synced_bn = BatchNorm2d(8)
+      unsynced_bn = UnsyncedBatchNorm(8, num_devices=len(devices))
 
       for p in get_parameters([synced_bn, unsynced_bn]):
         p.shard_(devices)
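The synced-vs-unsynced distinction the test above exercises is easiest to see with data that differs per device: synced batchnorm computes one mean over the whole global batch, unsynced batchnorm computes one mean per device shard. A toy numpy illustration (not part of the test, which feeds identical ones to both modules):

import numpy as np

# two "devices", each seeing a different half of an (8, 8) batch
x = np.concatenate([np.zeros((4, 8)), np.ones((4, 8))])
synced_mean = x.mean(axis=0)                      # one global mean: 0.5 for every channel
unsynced_mean = x.reshape(2, 4, 8).mean(axis=1)   # one mean per device: 0.0 and 1.0
print(synced_mean[0], unsynced_mean[:, 0])        # 0.5 [0. 1.]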
@@ -62,6 +62,24 @@ class TestNN(unittest.TestCase):
   def test_batchnorm2d_training(self):
     self.test_batchnorm2d(True)
 
+  def test_batchnorm_axis(self):
+    sz = (2, 4, 3, 2, 2)
+    x = Tensor.randn(sz)
+    weight = Tensor.randn(2, 3)
+    bias = Tensor.randn(2, 3)
+    mean = Tensor.randn(2, 3)
+    invstd = Tensor.randn(2, 3)
+    a = (x.batchnorm(weight, bias, mean, invstd, axis=(0, 2))
+         .permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2))
+    b = (x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2)
+         .batchnorm(weight.flatten(), bias.flatten(), mean.flatten(), invstd.flatten()))
+    t_x = torch.tensor(x.permute(1, 0, 2, 3, 4).reshape(4, 6, 2, 2).numpy())
+    t_weight, t_bias = torch.tensor(weight.flatten().numpy()), torch.tensor(bias.flatten().numpy())
+    t_mean, t_invstd = torch.tensor(mean.flatten().numpy()), torch.tensor(invstd.flatten().numpy())
+    torch.nn.functional.batch_norm(t_x, t_mean, 1.0 / t_invstd**2, t_weight, t_bias)
+
+    np.testing.assert_allclose(a.numpy(), b.numpy())
+
   def test_linear(self):
     def _test_linear(x, in_dim, out_dim):
       # create in tinygrad
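A note on the torch comparison in test_batchnorm_axis: Tensor.batchnorm is handed an inverse standard deviation (the invstd argument above), whereas torch.nn.functional.batch_norm takes a variance, hence the 1.0 / t_invstd**2 conversion. Since invstd is normally 1/sqrt(var + eps), the conversion is exact only up to eps. A quick numeric check of that identity:

import numpy as np

var = np.array([0.25, 1.0, 4.0])
invstd = 1.0 / np.sqrt(var)                        # what Tensor.batchnorm expects (eps ignored)
np.testing.assert_allclose(var, 1.0 / invstd**2)   # what the torch call is handed back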