diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 12f05eb31b..35d2222e78 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -317,8 +317,8 @@ class Tensor: return y.div((y*y).mean(axis=axis, keepdim=True).add(eps).sqrt()) def batchnorm(self, weight:Tensor, bias:Tensor, mean:Tensor, invstd:Tensor): - self = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1]) - return self.mul(invstd.reshape(shape=[1, -1, 1, 1])) + bias.reshape(shape=[1, -1, 1, 1]) + x = (self - mean.reshape(shape=[1, -1, 1, 1])) * weight.reshape(shape=[1, -1, 1, 1]) + return x.mul(invstd.reshape(shape=[1, -1, 1, 1])) + bias.reshape(shape=[1, -1, 1, 1]) # An instantiation of the Function is the Context class Function: