BatchNorm2D: track running stats by default, and detach the input before computing batch statistics

This commit is contained in:
George Hotz
2022-01-15 21:46:54 -08:00
parent d541e2a8e5
commit d844ecee27
2 changed files with 5 additions and 4 deletions

View File

@@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor
import numpy as np
class BatchNorm2D:
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=False, momentum=0.1):
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
assert affine == True
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
@@ -13,8 +13,9 @@ class BatchNorm2D:
def __call__(self, x):
if self.track_running_stats or Tensor.training:
batch_mean = x.mean(axis=(0,2,3))
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
x_detached = x.detach()
batch_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
batch_var = (y*y).mean(axis=(0,2,3))
if self.track_running_stats:

View File

@@ -169,7 +169,7 @@ class Tensor:
return ret
def detach(self):
return Tensor(self.data, device=self.device)
return Tensor(self.data, device=self.device, requires_grad=False)
# ***** non first class ops *****