From a8a62bc08e41bafd696c1c537e48a30f80ce30ac Mon Sep 17 00:00:00 2001 From: Douglas Nyberg <71201490+Douglas-Nyberg@users.noreply.github.com> Date: Thu, 4 Dec 2025 03:53:47 -0500 Subject: [PATCH] add max/min reduction support to ScatterND (#13562) --- test/external/external_test_onnx_backend.py | 3 --- tinygrad/nn/onnx.py | 5 +++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index d467d24ece..235d169fa1 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -184,9 +184,6 @@ backend_test.exclude('test_ai_onnx_ml_label_encoder_tensor_mapping_cpu') # bad d backend_test.exclude('test_if_opt_cpu') # ValueError: 13 is not a valid AttributeType backend_test.exclude('test_if_seq_cpu') # NotImplementedError: op='SequenceConstruct' is not supported -backend_test.exclude('test_scatternd_min_cpu') # min not yet supported -backend_test.exclude('test_scatternd_max_cpu') # max not yet supported - # regression from removing StrEnum in Domain backend_test.exclude('test_adam_cpu') backend_test.exclude('test_gradient_of_add_and_mul_cpu') diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index dd4d373d52..35d1925864 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -1158,7 +1158,7 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) - def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul", "max", "min"]='none'): assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] x = x.contiguous() for index, u in zip(indices.split(1, 0), updates.split(1, 0)): @@ -1167,7 +1167,8 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT if reduction == "none": x[i] = u elif reduction == "add": x[i] += u elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") + elif reduction == "max": x[i] = x[i].maximum(u) + elif reduction == "min": x[i] = x[i].minimum(u) return x def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul", "min", "max"]="none"):