Onnx Layer/Group/RMS/Batch-Norm ReduceL2 fp32 intermediates for fp16 (#12109)

* match onnx spec

* use least_upper_dtype

* promote the square

* just cast before the square
Author: Sieds Lykles
Date: 2025-10-24 12:26:11 +02:00 (committed by GitHub)
parent 0bde87d8d7
commit e1f8c82938
2 changed files with 20 additions and 12 deletions
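
Background note (illustration, not part of the commit): ReduceL2 squares its input before reducing, and float16 tops out at 65504, so with inputs on the scale of randn()*100 both the per-element squares and the running sum overflow to inf if kept in half precision. A minimal NumPy sketch of the failure mode the fp32 intermediates avoid:

import numpy as np
x = (np.random.randn(1, 1, 32, 32, 32) * 100).astype(np.half)  # same scale as the new test below
naive = np.sqrt(np.square(x).sum())                             # squares and sum stay float16 -> typically inf
promoted = np.sqrt(np.square(x.astype(np.float32)).sum()).astype(np.half)  # fp32 intermediates, cast back at the end
print(naive, promoted)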


@@ -272,6 +272,10 @@ class TestMainOnnxOps(TestOnnxOps):
   def test_qlinearmatmul_2D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 2)
   def test_qlinearmatmul_3D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 3)
+  def test_reduce_l2_half(self):
+    inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
+    self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
 class TestTrainingOnnxOps(TestOnnxOps):
   # NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
   DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
@@ -487,4 +491,4 @@ class TestContribOnnxOps(TestOnnxOps):
     self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
 
 if __name__ == "__main__":
-  unittest.main()
\ No newline at end of file
+  unittest.main()


@@ -5,7 +5,7 @@ from io import BufferedReader
 from tinygrad.nn.state import TensorIO
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
 from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
-from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate
+from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate, least_upper_dtype
 from tinygrad.device import is_dtype_supported, Device
 
 # ***** protobuf definitions ******
@@ -670,7 +670,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
   def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
   def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-    return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
+    dtype = dtypes.float if data.dtype in (dtypes.float16, dtypes.bfloat16) else data.dtype
+    return ReduceSum(data.cast(dtype).square(), axes, keepdims, noop_with_empty_axes).sqrt().cast(data.dtype)
   def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
   def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
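
A small usage sketch of the promoted-square pattern the new ReduceL2 follows (assuming a recent tinygrad and a device with float16 support; illustrative, not code from this diff):

from tinygrad import Tensor, dtypes
x = Tensor.randn(1, 1, 32, 32, 32, dtype=dtypes.half) * 100
# square in float32, reduce, take the square root, then cast back to the input dtype
l2 = x.cast(dtypes.float).square().sum(axis=(2, 3, 4), keepdim=True).sqrt().cast(x.dtype)
print(l2.numpy())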
@@ -897,7 +898,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
   def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                          training_mode:int=0, spatial=1, is_test=0):
     if training_mode:
-      x_detached = X.detach()
+      x_detached = X.detach().cast(least_upper_dtype(X.dtype, dtypes.float32))
       current_mean = x_detached.mean(axis=(0,2,3))
       y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
       current_var = (y*y).mean(axis=(0,2,3))
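
Why least_upper_dtype rather than an unconditional cast to float32 (my reading of the change, not stated in the diff): it promotes half/bfloat16 statistics to float32 while leaving already-wider inputs untouched. A minimal sketch:

from tinygrad.dtype import dtypes, least_upper_dtype
print(least_upper_dtype(dtypes.half, dtypes.float32))      # promotes to float32
print(least_upper_dtype(dtypes.bfloat16, dtypes.float32))  # promotes to float32
print(least_upper_dtype(dtypes.float64, dtypes.float32))   # stays float64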
@@ -906,18 +907,20 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
       running_mean = input_mean * momentum + current_mean * (1 - momentum)
       running_var = input_var * momentum + current_var * (1 - momentum)
-      return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
+      return X.batchnorm(scale, B, current_mean, current_invstd).cast(X.dtype),running_mean.cast(input_mean.dtype),running_var.cast(input_var.dtype)
     return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt())
-  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
-    x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
+  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    x = x.reshape(x.shape[0], num_groups, -1).cast(dtypes.float).layernorm(eps=epsilon).cast(x.dtype).reshape(x.shape)
     return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2))
   def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
     return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon)
   def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
     assert stash_type == 1, "only float32 is supported"
     axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
-    mean = x.mean(axis=axes, keepdim=True)
-    return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    mean = (x32:=x.cast(dtypes.float)).mean(axis=axes, keepdim=True)
+    inv_std_dev = (x32.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    return (x32.sub(mean)*inv_std_dev).cast(x.dtype).mul(scale).add(bias), mean, inv_std_dev
   def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
     x = x + skip
     if bias is not None: x = x + bias
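
For reference, ONNX's stash_type attribute selects the dtype used for the first-stage statistics and for the Mean/InvStdDev outputs; 1 means float32, which is all this implementation accepts. A rough NumPy sketch (not from the diff) of the behaviour the LayerNormalization change above implements:

import numpy as np

def layer_norm_fp32_stash(x, scale, bias, axis=-1, eps=1e-5):
  axes = tuple(range(axis % x.ndim, x.ndim))
  x32 = x.astype(np.float32)                       # statistics in float32 even for fp16 inputs
  mean = x32.mean(axis=axes, keepdims=True)
  inv_std = 1.0 / np.sqrt(np.square(x32 - mean).mean(axis=axes, keepdims=True) + eps)
  y = ((x32 - mean) * inv_std).astype(x.dtype) * scale + bias   # normalized result back in the input dtype
  return y, mean, inv_std                          # Mean and InvStdDev stay in the stash dtype

x = (np.random.randn(2, 8) * 100).astype(np.half)
y, mean, inv_std = layer_norm_fp32_stash(x, np.ones(8, np.half), np.zeros(8, np.half))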
@@ -1089,9 +1092,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionType]]:
     return output, present_key, present_value, qk_matmul_return_val
   Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib}
-  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5):
-    norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
-    return X * norm * scale
+  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    norm = X.cast(dtypes.float).square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
+    return X.cast(X.dtype) * norm * scale
   def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None,
                       rotary_embedding_dim:int=0):
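
Similarly for RMSNormalization, a rough NumPy reference (illustrative, not from the diff) with the mean of squares and the rsqrt taken in float32:

import numpy as np

def rms_norm_fp32_stash(X, scale, axis=-1, eps=1e-5):
  axes = tuple(range(axis % X.ndim, X.ndim))
  # mean of squares and reciprocal square root computed in float32 even for fp16 inputs
  norm = 1.0 / np.sqrt(np.square(X.astype(np.float32)).mean(axis=axes, keepdims=True) + eps)
  return X * norm * scale

X = (np.random.randn(2, 16) * 100).astype(np.half)
print(rms_norm_fp32_stash(X, np.ones(16, np.half)).shape)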