mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-19 02:44:40 -05:00)
fix binop, other tests failure (#723)
* fix binop, other tests failure
* that was a bad idea
* better layernorm
* inference kernel count tests
* new style reshape pushing
* fixup replacement
* 199 kernels is okay. fix flops
* push reshape through unaryops only
* GRAPH=2 draws the phantom ops
* found resnet issue
* non working test
* mul is cheaper than div (see the sketch after this list)
* OPT inflation
* SHUFFLE_PAD_OPS in OPT=2
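The "mul is cheaper than div" bullet refers to the trick the diff below applies to BatchNorm2d: rather than dividing by the standard deviation for every element, precompute the reciprocal square root once and multiply. A minimal sketch of that equivalence in plain Python (not tinygrad code; the variable names are illustrative):

import math

running_var, eps = 0.25, 1e-5
x_centered = 2.0  # stand-in for one already-mean-subtracted activation

# normalizing with a divide per element ...
with_div = x_centered / math.sqrt(running_var + eps)

# ... matches multiplying by a precomputed reciprocal square root,
# turning the per-element divide into a cheaper multiply
invstd = (running_var + eps) ** -0.5
with_mul = x_centered * invstd

assert math.isclose(with_div, with_mul)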
@@ -21,7 +21,6 @@ class BatchNorm2d:
      y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
      batch_var = (y*y).mean(axis=(0,2,3))
      batch_invstd = batch_var.add(self.eps).pow(-0.5)
-     self.batch_invstd = None

      # NOTE: wow, this is done all throughout training in most PyTorch models
      if self.track_running_stats:
@@ -29,11 +28,9 @@ class BatchNorm2d:
        self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
        self.num_batches_tracked += 1
    else:
-     batch_mean, batch_var = self.running_mean, self.running_var
-     # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
-     if not hasattr(self, "batch_invstd") or not self.batch_invstd:
-       self.batch_invstd = batch_var.add(self.eps).pow(-0.5)
-     batch_invstd = self.batch_invstd
+     batch_mean = self.running_mean
+     # NOTE: this can be precomputed for static inference. we expand it here so it fuses
+     batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()

    return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
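For reference, a small numpy sketch of the inference path the new code produces: the running statistics are reshaped to (1, C, 1, 1) so they broadcast over an NCHW input, and the inverse standard deviation is precomputed with a reciprocal square root. This assumes, as the diff suggests, that Tensor.batchnorm computes weight * (x - mean) * invstd + bias; the helper name below is made up for illustration.

import numpy as np

def batchnorm_inference(x, weight, bias, running_mean, running_var, eps=1e-5):
  # per-channel stats broadcast over the NCHW input after a (1, C, 1, 1) reshape
  mean = running_mean.reshape(1, -1, 1, 1)
  # precomputed inverse standard deviation: one rsqrt, then only multiplies per element
  invstd = 1.0 / np.sqrt(running_var.reshape(1, -1, 1, 1) + eps)
  return weight.reshape(1, -1, 1, 1) * (x - mean) * invstd + bias.reshape(1, -1, 1, 1)

x = np.random.randn(2, 3, 8, 8).astype(np.float32)
w, b = np.ones(3, np.float32), np.zeros(3, np.float32)
rm, rv = np.zeros(3, np.float32), np.ones(3, np.float32)
print(batchnorm_inference(x, w, b, rm, rv).shape)  # (2, 3, 8, 8)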