From d844ecee272957d3dd106d3089b5af345eb6c4e4 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 15 Jan 2022 21:46:54 -0800 Subject: [PATCH] track running stats by default, and detach --- tinygrad/nn.py | 7 ++++--- tinygrad/tensor.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tinygrad/nn.py b/tinygrad/nn.py index d0ad7fa40f..feba126221 100644 --- a/tinygrad/nn.py +++ b/tinygrad/nn.py @@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor import numpy as np class BatchNorm2D: - def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=False, momentum=0.1): + def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1): assert affine == True self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum @@ -13,8 +13,9 @@ class BatchNorm2D: def __call__(self, x): if self.track_running_stats or Tensor.training: - batch_mean = x.mean(axis=(0,2,3)) - y = (x - batch_mean.reshape(shape=[1, -1, 1, 1])) + x_detached = x.detach() + batch_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1])) batch_var = (y*y).mean(axis=(0,2,3)) if self.track_running_stats: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2f0f47320a..bd6c6c3296 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -169,7 +169,7 @@ class Tensor: return ret def detach(self): - return Tensor(self.data, device=self.device) + return Tensor(self.data, device=self.device, requires_grad=False) # ***** non first class ops *****