[FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275)

This fixes a few bugs related to scalar (0-d) tensors:
- `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type is forbidden')`
- `scalar[None]` fails with `TypeError("'constexpr' object is not iterable")`
- `scalar[None, None]` fails with `AttributeError("'dtype' object has no attribute 'shape'")`
- `scalar.shape` returns `[1]` instead of the 0-dim shape `[]`
- Relatedly, `tl.zeros_like(scalar)` returns a 1-d tensor instead of another scalar
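
For reference, a minimal sketch of the behaviours listed above (not part of the diff; it assumes a CUDA device and a build containing this fix, and the kernel/variable names are purely illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scalar_demo_kernel(out_ptr, BLOCK: tl.constexpr):
    scalar = tl.full([], 7, tl.int32)        # 0-d block is now accepted
    tl.static_assert(scalar.shape == [])     # shape is [], not [1]
    zero = tl.zeros_like(scalar)             # stays a scalar instead of becoming 1-d
    tl.static_assert(zero.shape == [])
    row = scalar[None]                       # indexing a scalar with None adds a dim
    tl.static_assert(row.shape == [1])
    tile = scalar[None, None]
    tl.static_assert(tile.shape == [1, 1])
    offsets = tl.arange(0, BLOCK)
    tl.store(out_ptr + offsets, scalar + zero)  # scalar value broadcasts over the block


out = torch.empty(16, dtype=torch.int32, device="cuda")
scalar_demo_kernel[(1,)](out, BLOCK=16)
assert torch.all(out == 7)
```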
Author: peterbell10
Date: 2023-09-12 04:59:13 +01:00
Committed by: GitHub
Parent: bf4f9375a7
Commit: ab9da3b2b8
3 changed files with 81 additions and 46 deletions


@@ -564,35 +564,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
    _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas)


# ---------------
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype):
    @triton.jit
    def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
        offset1 = tl.arange(0, M)
        offset2 = tl.arange(0, N)
        x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
        y = tl.load(y_ptr + offset2)
        _, y_broadcasted = tl.broadcast(x, y)
        tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)

    M = 32
    N = 64
    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype, rs=rs)
    y = numpy_random(N, dtype_str=dtype, rs=rs)
    _, y_broadcasted_np = np.broadcast_arrays(x, y)
    x_tri = to_triton(x, device='cuda', dst_type=dtype)
    y_tri = to_triton(y, device='cuda', dst_type=dtype)
    y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
    broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
    assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()


# ---------------
# test broadcast
# ---------------
@@ -621,6 +592,36 @@ def test_broadcast(dtype, device):
    broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
    assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()


# ----------
# test slice
# ----------
def test_slice(device):
    @triton.jit
    def slice_kernel(XBLOCK: tl.constexpr):
        data = tl.arange(0, XBLOCK)
        tl.static_assert(data.shape == [XBLOCK])

        t = data[None, :]
        tl.static_assert(t.shape == [1, XBLOCK])

        t = data[None, :, None]
        tl.static_assert(t.shape == [1, XBLOCK, 1])

        scalar = tl.full([], 1, tl.int32)
        tl.static_assert(scalar.shape == [])

        t = scalar[None]
        tl.static_assert(t.shape == [1])

        t = scalar[None, None]
        tl.static_assert(t.shape == [1, 1])

    slice_kernel[(1,)](XBLOCK=32)


# ------------------
# test invalid slice
# ------------------
@@ -669,6 +670,14 @@ def test_expand_dims(device):
        t = tl.expand_dims(offset1, (3, 1, 2))
        tl.static_assert(t.shape == [N, 1, 1, 1])

        scalar = tl.sum(offset1)
        tl.static_assert(scalar.shape == [])
        t = tl.expand_dims(scalar, 0)
        tl.static_assert(t.shape == [1])
        t = tl.expand_dims(scalar, -1)
        tl.static_assert(t.shape == [1])

    N = 32
    dummy_tensor = torch.empty((), device=device)
    expand_dims_kernel[(1,)](dummy_tensor, N)
@@ -689,6 +698,13 @@ def test_expand_dims_error_cases(device):
        t = tl.expand_dims(offset1, 1)
        t = tl.expand_dims(offset1, 2)

    @triton.jit
    def dim_out_of_range3(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, 1)
        scalar = tl.sum(offset1)
        t = tl.expand_dims(scalar, 1)

    @triton.jit
    def duplicate_dim1(dummy, N: tl.constexpr):
        offset1 = tl.arange(0, N)
@@ -710,6 +726,9 @@ def test_expand_dims_error_cases(device):
    with pytest.raises(triton.CompilationError, match="invalid axis 2"):
        dim_out_of_range2[(1,)](dummy_tensor, N)

    with pytest.raises(triton.CompilationError, match="invalid axis 1"):
        dim_out_of_range3[(1,)](dummy_tensor, N)

    with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
        duplicate_dim1[(1,)](dummy_tensor, N)
@@ -2467,7 +2486,8 @@ def test_dot_mulbroadcastred(in_dtype, device):
@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str, device):
@pytest.mark.parametrize("shape", [(), (1,), (128,)])
def test_full(dtype_str, shape, device):
    if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
        # PyTorch only has unsigned 8, but not 16, 32, or 64
        dtype = getattr(torch, dtype_str[1:])  # uintx -> intx
@@ -2478,21 +2498,28 @@ def test_full(dtype_str, device):
    @triton.jit
    def kernel_static(out):
        a = GENERATE_TEST_HERE
        tl.static_assert(a.shape == SHAPE)
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    @triton.jit
    def kernel_dynamic(out, val, dtype: tl.constexpr):
        a = tl.full((128,), val, dtype)
        a = tl.full(SHAPE, val, dtype)
        tl.static_assert(a.shape == SHAPE)
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
    kernel_static_patched = patch_kernel(kernel_static, {
        'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})",
        'SHAPE': str(list(shape)),
    })
    out_static = torch.zeros((128), dtype=dtype, device=device)
    kernel_static_patched[(1,)](out_static)
    out_dynamic = torch.zeros((128), dtype=dtype, device=device)
    kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
    assert torch.all(out_static == 2)

    kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))})
    out_dynamic = torch.zeros((128), dtype=dtype, device=device)
    kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
    assert torch.all(out_dynamic == 2)
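
Note (not part of the diff): with the new `shape` parametrization, `tl.full(())` yields a 0-d value and `tl.full((1,))` a 1-element block; both of those cases rely on the store broadcasting the value across the 128-element pointer range built from `tl.arange(0, 128)`, while the `(128,)` case matches that range exactly.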