Mirror of https://github.com/tinygrad/tinygrad.git
InstanceNormalization ONNX test fixed. (#870)
@@ -40,6 +40,12 @@ def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, moment
   invstd = (input_var + epsilon)**-0.5
   return X.batchnorm(scale, B, input_mean, invstd)
 
+def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
+  axis = tuple(range(2, len(x.shape)))
+  mean = x.mean(axis=axis, keepdim=True)
+  invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
+  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
+
 def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
   assert stash_type == 1, "only float32 is supported"
   axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))
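For reference, a minimal NumPy sketch (not part of the commit) of the same InstanceNormalization math: normalize each sample and channel over its spatial axes, then apply the per-channel scale and bias. The helper name instance_norm_ref and the 4-D NCHW test input are assumptions for illustration only.

import numpy as np

def instance_norm_ref(x, scale, bias, epsilon=1e-05):
  # normalize over the spatial axes (2, 3, ...) independently per sample and channel
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdims=True)
  var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
  # same order of operations as the tinygrad version: (x - mean) * scale * invstd + bias
  return (x - mean) * scale.reshape(-1, 1, 1) * (var + epsilon) ** -0.5 + bias.reshape(-1, 1, 1)

x = np.random.randn(2, 3, 4, 4).astype(np.float32)
out = instance_norm_ref(x, np.ones(3, np.float32), np.zeros(3, np.float32))
assert np.allclose(out.mean(axis=(2, 3)), 0, atol=1e-4)  # with scale=1, bias=0 each channel is zero-mean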