fix binop and other test failures (#723)

* fix binop and other test failures

* that was a bad idea

* better layernorm

* inference kernel count tests (see the test sketch after this list)

* new style reshape pushing

* fixup replacement

* 199 kernels is okay. fix flops

* push reshape through unaryops only (see the rewrite sketch below)

* GRAPH=2 draws the phantom ops

* found resnet issue

* non-working test

* mul is cheaper than div (see the sketch below)

* OPT inflation

* SHUFFLE_PAD_OPS in OPT=2
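
The "inference kernel count tests" above pin the number of kernels one forward pass launches, so a lost fusion shows up as a count regression ("199 kernels is okay" sets that baseline). A minimal sketch of the shape of such a test, using a hypothetical `KernelCounter` rather than tinygrad's real counter API:

```python
# sketch only: `KernelCounter` and `model` are illustrative placeholders,
# not tinygrad's actual API
class KernelCounter:
  count = 0
  @classmethod
  def reset(cls): cls.count = 0

def test_inference_kernel_count(model, x, expected=199):
  KernelCounter.reset()
  out = model(x)
  out.realize()  # force the whole graph to compile and run
  # if fusion regresses, more kernels launch and the test fails
  assert KernelCounter.count <= expected, \
    f"kernel count regressed: {KernelCounter.count} > {expected}"
```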
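The reshape-pushing bullets refer to a movement-op rewrite: an elementwise unary op treats every element independently, so `reshape(unary(x))` equals `unary(reshape(x))` and the reshape can be floated toward the inputs, where it can merge with other movement ops; pushing through binary ops is avoided here since the reshape would have to be replicated into both operands. A toy version of the rewrite on a tuple-based expression tree (a hypothetical representation, not tinygrad's internals):

```python
# nodes: ("reshape", child, shape) | (unary_op, child) | ("input",)
UNARY_OPS = {"exp", "log", "relu", "neg"}

def push_reshape(node):
  """Rewrite reshape(unary(x)) -> unary(reshape(x)), recursively."""
  if node[0] == "reshape" and node[1][0] in UNARY_OPS:
    _, (uop, uchild), shape = node
    return (uop, push_reshape(("reshape", uchild, shape)))
  return node

print(push_reshape(("reshape", ("exp", ("input",)), (4, 4))))
# ('exp', ('reshape', ('input',), (4, 4)))
```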
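"mul is cheaper than div" is classic strength reduction: a hardware divide costs far more than a multiply, so dividing many elements by the same value is rewritten as one reciprocal plus multiplies (for a compile-time constant the divide folds away entirely). A plain-Python sketch of the idea, not the actual kernel rewrite:

```python
def normalize_div(xs, c):
  return [x / c for x in xs]       # one divide per element

def normalize_mul(xs, c):
  inv_c = 1.0 / c                  # divide once (constant-folded if c is static)
  return [x * inv_c for x in xs]   # only multiplies in the hot loop

# exact for powers of two; within rounding error otherwise
assert normalize_div([2.0, 4.0], 2.0) == normalize_mul([2.0, 4.0], 2.0)
```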
Author: George Hotz
Date: 2023-03-22 18:15:07 -07:00 (committed by GitHub)
Parent: d6f4219952 · Commit: b12b60af20
13 changed files with 234 additions and 105 deletions

@@ -21,7 +21,6 @@ class BatchNorm2d:
       y = (x_detached - batch_mean.reshape(shape=[1, -1, 1, 1]))
       batch_var = (y*y).mean(axis=(0,2,3))
       batch_invstd = batch_var.add(self.eps).pow(-0.5)
-      self.batch_invstd = None
 
       # NOTE: wow, this is done all throughout training in most PyTorch models
       if self.track_running_stats:
@@ -29,11 +28,9 @@ class BatchNorm2d:
         self.running_var.assign((1 - self.momentum) * self.running_var + self.momentum * batch_var)
         self.num_batches_tracked += 1
     else:
-      batch_mean, batch_var = self.running_mean, self.running_var
-      # NOTE: this can be precomputed for static inference. if you manually update running_var, you have to reset this
-      if not hasattr(self, "batch_invstd") or not self.batch_invstd:
-        self.batch_invstd = batch_var.add(self.eps).pow(-0.5)
-      batch_invstd = self.batch_invstd
+      batch_mean = self.running_mean
+      # NOTE: this can be precomputed for static inference. we expand it here so it fuses
+      batch_invstd = self.running_var.reshape(1, -1, 1, 1).expand(x.shape).add(self.eps).rsqrt()
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_invstd)
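
The new `else` branch recomputes `batch_invstd` from the running variance on every call, expanded to `x.shape` so it fuses into the surrounding elementwise kernel; per the NOTE, a static-inference deployment could instead fold the running stats once. A NumPy sketch of that folding (illustrative, not tinygrad code): inference batchnorm collapses to one scale and one shift per channel.

```python
import numpy as np

def fold_batchnorm(running_mean, running_var, weight, bias, eps=1e-5):
  # y = (x - mean) * invstd * weight + bias  ==  x * scale + shift
  invstd = 1.0 / np.sqrt(running_var + eps)
  scale = weight * invstd
  shift = bias - running_mean * scale
  return scale, shift

def batchnorm_inference(x, scale, shift):
  # x: (N, C, H, W); scale/shift broadcast over the channel axis
  return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)
```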