mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
batchnorm d(var)/d(mean) = 0 (#4430)
* d(var)/d(mean) = 0 * drop the number in test_schedule!
This commit is contained in:
@@ -49,7 +49,7 @@ class UnsyncedBatchNorm:
|
|||||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||||
batch_mean = x.mean(axis=(1,3,4))
|
batch_mean = x.mean(axis=(1,3,4))
|
||||||
y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
|
y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1])) # d(var)/d(mean) = 0
|
||||||
batch_var = (y*y).mean(axis=(1,3,4))
|
batch_var = (y*y).mean(axis=(1,3,4))
|
||||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||||
|
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ class TestSchedule(unittest.TestCase):
|
|||||||
opt.zero_grad()
|
opt.zero_grad()
|
||||||
img_bn.backward()
|
img_bn.backward()
|
||||||
# this is too high
|
# this is too high
|
||||||
check_schedule(opt.schedule_step(), 18)
|
check_schedule(opt.schedule_step(), 17)
|
||||||
|
|
||||||
def test_fold_conv_relu(self):
|
def test_fold_conv_relu(self):
|
||||||
c1 = nn.Conv2d(3,16,3)
|
c1 = nn.Conv2d(3,16,3)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class BatchNorm2d:
|
|||||||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||||
batch_mean = x.mean(axis=(0,2,3))
|
batch_mean = x.mean(axis=(0,2,3))
|
||||||
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1])) # d(var)/d(mean) = 0
|
||||||
batch_var = (y*y).mean(axis=(0,2,3))
|
batch_var = (y*y).mean(axis=(0,2,3))
|
||||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user