From e1f8c82938d5fe2fe186c172c0bdcc32e7b04c81 Mon Sep 17 00:00:00 2001
From: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
Date: Fri, 24 Oct 2025 12:26:11 +0200
Subject: [PATCH] Onnx Layer/Group/RMS/Batch-Norm ReduceL2 fp32 intermediates
 for fp16 (#12109)

* match onnx spec

* use least_upper_dtype

* promote the square

* just cast before the square
---
 test/external/external_test_onnx_ops.py |  6 +++++-
 tinygrad/nn/onnx.py                     | 26 +++++++++++++++-----------
 2 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py
index 3e1cc9503f..ce62b32b58 100644
--- a/test/external/external_test_onnx_ops.py
+++ b/test/external/external_test_onnx_ops.py
@@ -272,6 +272,10 @@ class TestMainOnnxOps(TestOnnxOps):
   def test_qlinearmatmul_2D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 2)
   def test_qlinearmatmul_3D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 3)
 
+  def test_reduce_l2_half(self):
+    inputs = {"data": np.random.randn(1, 1, 32, 32, 32).astype(np.half)*100}
+    self.helper_test_single_op("ReduceL2", inputs, {}, ["reduced"])
+
 class TestTrainingOnnxOps(TestOnnxOps):
   # NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
   DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
@@ -487,4 +491,4 @@ class TestContribOnnxOps(TestOnnxOps):
     self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs)
 
 if __name__ == "__main__":
-  unittest.main()
\ No newline at end of file
+  unittest.main()
diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py
index 46f0a193e0..4bcdde2fb2 100644
--- a/tinygrad/nn/onnx.py
+++ b/tinygrad/nn/onnx.py
@@ -5,7 +5,7 @@ from io import BufferedReader
 from tinygrad.nn.state import TensorIO
 from tinygrad.tensor import Tensor, _broadcast_shape, ReductionStr
 from tinygrad.helpers import getenv, DEBUG, all_same, prod, flatten, make_tuple, argsort, is_numpy_ndarray, get_single_element, polyN
-from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate
+from tinygrad.dtype import DType, ConstType, dtypes, _from_np_dtype, truncate, least_upper_dtype
 from tinygrad.device import is_dtype_supported, Device
 
 # ***** protobuf definitions ******
@@ -670,7 +670,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def ReduceL1(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data.abs(), axes, keepdims, noop_with_empty_axes)
   def ReduceL2(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
-    return ReduceSumSquare(data, axes, keepdims, noop_with_empty_axes).sqrt()
+    dtype = dtypes.float if data.dtype in (dtypes.float16, dtypes.bfloat16) else data.dtype
+    return ReduceSum(data.cast(dtype).square(), axes, keepdims, noop_with_empty_axes).sqrt().cast(data.dtype)
   def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
     return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
   def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
@@ -897,7 +898,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                          training_mode:int=0, spatial=1, is_test=0):
     if training_mode:
-      x_detached = X.detach()
+      x_detached = X.detach().cast(least_upper_dtype(X.dtype, dtypes.float32))
       current_mean = x_detached.mean(axis=(0,2,3))
       y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
       current_var = (y*y).mean(axis=(0,2,3))
@@ -906,18 +907,20 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       running_mean = input_mean * momentum + current_mean * (1 - momentum)
       running_var = input_var * momentum + current_var * (1 - momentum)
 
-      return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
+      return X.batchnorm(scale, B, current_mean, current_invstd).cast(X.dtype),running_mean.cast(input_mean.dtype),running_var.cast(input_var.dtype)
     return X.batchnorm(scale, B, input_mean, (input_var + epsilon).rsqrt())
-  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
-    x = x.reshape(x.shape[0], num_groups, -1).layernorm(eps=epsilon).reshape(x.shape)
+  def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    x = x.reshape(x.shape[0], num_groups, -1).cast(dtypes.float).layernorm(eps=epsilon).cast(x.dtype).reshape(x.shape)
     return x * scale.reshape(1, -1, *[1] * (x.ndim-2)) + bias.reshape(1, -1, *[1] * (x.ndim-2))
   def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
     return GroupNormalization(x, scale, bias, num_groups=cast(int, x.shape[1]), epsilon=epsilon)
   def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
     assert stash_type == 1, "only float32 is supported"
     axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
-    mean = x.mean(axis=axes, keepdim=True)
-    return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    mean = (x32:=x.cast(dtypes.float)).mean(axis=axes, keepdim=True)
+    inv_std_dev = (x32.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
+    return (x32.sub(mean)*inv_std_dev).cast(x.dtype).mul(scale).add(bias), mean, inv_std_dev
   def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
     x = x + skip
     if bias is not None: x = x + bias
@@ -1089,9 +1092,10 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
       return output, present_key, present_value, qk_matmul_return_val
   Attention = {OpSetId(Domain.ONNX, 1): attention_onnx, OpSetId(Domain.MICROSOFT_CONTRIB_OPS, 1): attention_contrib}
 
-  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5):
-    norm = X.square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
-    return X * norm * scale
+  def RMSNormalization(X:Tensor, scale:Tensor, axis:int=-1, epsilon:float=1e-5, stash_type:int=1):
+    assert stash_type == 1, "only float32 is supported"
+    norm = X.cast(dtypes.float).square().mean(axis=tuple(range(axis + X.ndim if axis < 0 else axis, X.ndim)), keepdim=True).add(epsilon).rsqrt()
+    return (X.cast(dtypes.float) * norm).cast(X.dtype) * scale
 
   def RotaryEmbedding(X:Tensor, cos_cache:Tensor, sin_cache:Tensor, position_ids:Tensor|None=None, interleaved:int=0, num_heads:int|None=None,
                       rotary_embedding_dim:int=0):
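
Context for reviewers (illustrative only, not part of the diff): the failure mode
the fp32 intermediates fix is fp16 overflow in the sum of squares. A minimal
numpy-only sketch, reusing the shape and *100 scale of test_reduce_l2_half:

    import numpy as np

    # fp16 inputs of magnitude ~100, same shape/scale as the new test.
    x = np.full((1, 1, 32, 32, 32), 100.0, dtype=np.float16)

    # Naive ReduceL2: square and accumulate in fp16. The partial sums pass
    # fp16's max finite value (65504) after a few additions and become inf.
    naive = np.sqrt(np.square(x).sum())

    # Patched behavior: cast before the square, reduce in fp32, cast back.
    stable = np.sqrt(np.square(x.astype(np.float32)).sum()).astype(np.float16)

    print(naive)   # inf
    print(stable)  # ~18100, sqrt(32768 * 1e4), comfortably finite in fp16

The same reasoning drives the LayerNormalization/GroupNormalization/
RMSNormalization changes: ONNX's stash_type=1 specifies float32 for the stashed
statistics (mean, inverse standard deviation) regardless of the input dtype,
so the normalization is computed in float32 and only the result is cast back.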