mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
fix batchnorm not realizing
This commit is contained in:
@@ -11,7 +11,7 @@ class BatchNorm2d:
     self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)
     self.num_batches_tracked = Tensor.zeros(1, requires_grad=False)

-  def __call__(self, x):
+  def __call__(self, x:Tensor):
     if Tensor.training:
       # This requires two full memory accesses to x
       # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -25,8 +25,8 @@ class BatchNorm2d:

       # NOTE: wow, this is done all throughout training in most PyTorch models
       if self.track_running_stats:
-        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
-        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
+        self.running_mean.assign((1 - self.momentum) * self.running_mean + self.momentum * batch_mean)
+        self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
         self.num_batches_tracked += 1
     else:
       batch_mean, batch_var = self.running_mean, self.running_var
@@ -10,6 +10,8 @@ class Optimizer:
         x.requires_grad = True

     self.params : List[Tensor] = [x for x in params if x.requires_grad]
+    self.buffers : List[Tensor] = [x for x in params if not x.requires_grad]  # buffers are still realized
+    self.realize()

   # TODO: this probably shouldn't change the gradients, just the ones used by the optimizer
   def clipnorm(self, amount=1):
@@ -24,7 +26,7 @@ class Optimizer:
|
||||
|
||||
def realize(self, extra=None):
|
||||
# TODO: corealize
|
||||
for p in extra + self.params if extra is not None else self.params:
|
||||
for p in extra + self.params + self.buffers if extra is not None else self.params + self.buffers:
|
||||
p.realize()
|
||||
|
||||
class SGD(Optimizer):
|
||||
|
||||
Reference in New Issue
Block a user