mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
track running stats by default, and detach
This commit is contained in:
@@ -2,7 +2,7 @@ from tinygrad.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
class BatchNorm2D:
|
||||
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=False, momentum=0.1):
|
||||
def __init__(self, sz, eps=1e-5, affine=True, track_running_stats=True, momentum=0.1):
|
||||
assert affine == True
|
||||
self.eps, self.track_running_stats, self.momentum = eps, track_running_stats, momentum
|
||||
|
||||
@@ -13,8 +13,9 @@ class BatchNorm2D:
|
||||
|
||||
def __call__(self, x):
|
||||
if self.track_running_stats or Tensor.training:
|
||||
batch_mean = x.mean(axis=(0,2,3))
|
||||
y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
x_detached = x.detach()
|
||||
batch_mean = x_detached.mean(axis=(0,2,3))
|
||||
y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
|
||||
if self.track_running_stats:
|
||||
|
||||
@@ -169,7 +169,7 @@ class Tensor:
|
||||
return ret
|
||||
|
||||
def detach(self):
|
||||
return Tensor(self.data, device=self.device)
|
||||
return Tensor(self.data, device=self.device, requires_grad=False)
|
||||
|
||||
# ***** non first class ops *****
|
||||
|
||||
|
||||
Reference in New Issue
Block a user