diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index f6194aaed9..759d042ec9 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -49,7 +49,7 @@ class UnsyncedBatchNorm: # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm batch_mean = x.mean(axis=(1,3,4)) - y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1])) + y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1])) # d(var)/d(mean) = 0 batch_var = (y*y).mean(axis=(1,3,4)) batch_invstd = batch_var.add(self.eps).pow(-0.5) diff --git a/test/test_schedule.py b/test/test_schedule.py index 58570a7017..98cc2ca5a9 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -208,7 +208,7 @@ class TestSchedule(unittest.TestCase): opt.zero_grad() img_bn.backward() # this is too high - check_schedule(opt.schedule_step(), 18) + check_schedule(opt.schedule_step(), 17) def test_fold_conv_relu(self): c1 = nn.Conv2d(3,16,3) diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py index a02f4b31b4..c04c1644f1 100644 --- a/tinygrad/nn/__init__.py +++ b/tinygrad/nn/__init__.py @@ -20,7 +20,7 @@ class BatchNorm2d: # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm batch_mean = x.mean(axis=(0,2,3)) - y = (x - batch_mean.reshape(shape=[1, -1, 1, 1])) + y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0 batch_var = (y*y).mean(axis=(0,2,3)) batch_invstd = batch_var.add(self.eps).pow(-0.5)