InstanceNormalization ONNX test fixed. (#870)

This commit is contained in:
skobsman
2023-05-31 00:07:44 +01:00
committed by GitHub
parent 0dab8edc97
commit 2e393f7ef2

View File

@@ -40,6 +40,12 @@ def BatchNormalization(X, scale, B, input_mean, input_var, epsilon=1e-05, moment
invstd = (input_var + epsilon)**-0.5
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x: Tensor, scale: Tensor, bias: Tensor, epsilon=1e-05):
axis = tuple(range(2, len(x.shape)))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).pow(2).mean(axis=axis, keepdim=True).add(epsilon).pow(-0.5)
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_type=1):
assert stash_type == 1, "only float32 is supported"
axis = tuple(i for i in range(axis if axis >= 0 else len(x.shape) + axis, len(x.shape)))