mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275)
This fixes a few bugs related to scalar tensors:
- `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type is forbidden')`
- `scalar[None]` fails with `TypeError("'constexpr' object is not iterable")`
- `scalar[None, None]` fails with `AttributeError("'dtype' object has no attribute 'shape'")`
- `scalar.shape` returns `[1]` instead of the 0-dim `[]`
- Also related: `tl.zeros_like(scalar)` returns a 1d tensor instead of another scalar
This commit is contained in:
@@ -564,35 +564,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
|
||||
_test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test broadcast
|
||||
# ---------------
|
||||
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
|
||||
def test_broadcast(dtype):
|
||||
@triton.jit
|
||||
def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, M)
|
||||
offset2 = tl.arange(0, N)
|
||||
x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
|
||||
y = tl.load(y_ptr + offset2)
|
||||
_, y_broadcasted = tl.broadcast(x, y)
|
||||
tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)
|
||||
|
||||
M = 32
|
||||
N = 64
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((M, N), dtype_str=dtype, rs=rs)
|
||||
y = numpy_random(N, dtype_str=dtype, rs=rs)
|
||||
_, y_broadcasted_np = np.broadcast_arrays(x, y)
|
||||
|
||||
x_tri = to_triton(x, device='cuda', dst_type=dtype)
|
||||
y_tri = to_triton(y, device='cuda', dst_type=dtype)
|
||||
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
|
||||
|
||||
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
|
||||
# ---------------
|
||||
# test broadcast
|
||||
# ---------------
|
||||
@@ -621,6 +592,36 @@ def test_broadcast(dtype, device):
|
||||
broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
# ----------
|
||||
# test slice
|
||||
# ----------
|
||||
|
||||
|
||||
def test_slice(device):
|
||||
|
||||
@triton.jit
|
||||
def slice_kernel(XBLOCK: tl.constexpr):
|
||||
data = tl.arange(0, XBLOCK)
|
||||
tl.static_assert(data.shape == [XBLOCK])
|
||||
|
||||
t = data[None, :]
|
||||
tl.static_assert(t.shape == [1, XBLOCK])
|
||||
|
||||
t = data[None, :, None]
|
||||
tl.static_assert(t.shape == [1, XBLOCK, 1])
|
||||
|
||||
scalar = tl.full([], 1, tl.int32)
|
||||
tl.static_assert(scalar.shape == [])
|
||||
|
||||
t = scalar[None]
|
||||
tl.static_assert(t.shape == [1])
|
||||
|
||||
t = scalar[None, None]
|
||||
tl.static_assert(t.shape == [1, 1])
|
||||
|
||||
slice_kernel[(1,)](XBLOCK=32)
|
||||
|
||||
|
||||
# ------------------
|
||||
# test invalid slice
|
||||
# ------------------
|
||||
@@ -669,6 +670,14 @@ def test_expand_dims(device):
|
||||
t = tl.expand_dims(offset1, (3, 1, 2))
|
||||
tl.static_assert(t.shape == [N, 1, 1, 1])
|
||||
|
||||
scalar = tl.sum(offset1)
|
||||
tl.static_assert(scalar.shape == [])
|
||||
t = tl.expand_dims(scalar, 0)
|
||||
tl.static_assert(t.shape == [1])
|
||||
|
||||
t = tl.expand_dims(scalar, -1)
|
||||
tl.static_assert(t.shape == [1])
|
||||
|
||||
N = 32
|
||||
dummy_tensor = torch.empty((), device=device)
|
||||
expand_dims_kernel[(1,)](dummy_tensor, N)
|
||||
@@ -689,6 +698,13 @@ def test_expand_dims_error_cases(device):
|
||||
t = tl.expand_dims(offset1, 1)
|
||||
t = tl.expand_dims(offset1, 2)
|
||||
|
||||
@triton.jit
|
||||
def dim_out_of_range3(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, 1)
|
||||
scalar = tl.sum(offset1)
|
||||
|
||||
t = tl.expand_dims(scalar, 1)
|
||||
|
||||
@triton.jit
|
||||
def duplicate_dim1(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
@@ -710,6 +726,9 @@ def test_expand_dims_error_cases(device):
|
||||
with pytest.raises(triton.CompilationError, match="invalid axis 2"):
|
||||
dim_out_of_range2[(1,)](dummy_tensor, N)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match="invalid axis 1"):
|
||||
dim_out_of_range3[(1,)](dummy_tensor, N)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
|
||||
duplicate_dim1[(1,)](dummy_tensor, N)
|
||||
|
||||
@@ -2467,7 +2486,8 @@ def test_dot_mulbroadcastred(in_dtype, device):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
|
||||
def test_full(dtype_str, device):
|
||||
@pytest.mark.parametrize("shape", [(), (1,), (128,)])
|
||||
def test_full(dtype_str, shape, device):
|
||||
if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
|
||||
# PyTorch only has unsigned 8, but not 16, 32, or 64
|
||||
dtype = getattr(torch, dtype_str[1:]) # uintx -> intx
|
||||
@@ -2478,21 +2498,28 @@ def test_full(dtype_str, device):
|
||||
@triton.jit
|
||||
def kernel_static(out):
|
||||
a = GENERATE_TEST_HERE
|
||||
tl.static_assert(a.shape == SHAPE)
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
@triton.jit
|
||||
def kernel_dynamic(out, val, dtype: tl.constexpr):
|
||||
a = tl.full((128,), val, dtype)
|
||||
a = tl.full(SHAPE, val, dtype)
|
||||
tl.static_assert(a.shape == SHAPE)
|
||||
out_ptr = out + tl.arange(0, 128)[:]
|
||||
tl.store(out_ptr, a)
|
||||
|
||||
kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
|
||||
kernel_static_patched = patch_kernel(kernel_static, {
|
||||
'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})",
|
||||
'SHAPE': str(list(shape)),
|
||||
})
|
||||
out_static = torch.zeros((128), dtype=dtype, device=device)
|
||||
kernel_static_patched[(1,)](out_static)
|
||||
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
|
||||
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
||||
assert torch.all(out_static == 2)
|
||||
|
||||
kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))})
|
||||
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
|
||||
kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
|
||||
assert torch.all(out_dynamic == 2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user