[FRONTEND] Fix expand_dims and tl.full to handle scalar tensors (#2275)
This fixes a few bugs related to scalar tensors (a minimal reproduction sketch follows the list):
- `tl.full([], fill_value, dtype)` fails with `TypeError('0d block_type
is forbidden')`
- `scalar[None]` fails with `TypeError("'constexpr' object is not
iterable")`
- `scalar[None, None]` fails with `AttributeError("'dtype' object has no
attribute 'shape'")`
- `scalar.shape` returns `[1]` instead of 0-dim `[]`
- Also related, `tl.zeros_like(scalar)` returns a 1d tensor instead of
another scalar
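The sketch below is not part of the commit; it is a minimal reproduction of the behaviors listed above, assuming a CUDA device and a Triton build that includes this fix. The kernel and variable names are illustrative only.

import torch
import triton
import triton.language as tl


@triton.jit
def scalar_tensor_kernel(out_ptr):
    # 0-d tl.full no longer raises "0d block_type is forbidden"
    scalar = tl.full([], 7, tl.int32)
    tl.static_assert(scalar.shape == [])

    # per the commit message, zeros_like of a scalar now stays 0-d
    zero = tl.zeros_like(scalar)
    tl.static_assert(zero.shape == [])

    # indexing with None expands a scalar, mirroring NumPy semantics
    t = scalar[None]
    tl.static_assert(t.shape == [1])
    t2 = scalar[None, None]
    tl.static_assert(t2.shape == [1, 1])

    # scalar arithmetic still works; expand to a 1-element block to store it
    val = scalar + zero
    tl.store(out_ptr + tl.arange(0, 1), val[None])


out = torch.zeros(1, dtype=torch.int32, device='cuda')
scalar_tensor_kernel[(1,)](out)
assert out.item() == 7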
@@ -564,35 +564,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
     _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas)


-# ---------------
-# test broadcast
-# ---------------
-@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
-def test_broadcast(dtype):
-    @triton.jit
-    def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
-        offset1 = tl.arange(0, M)
-        offset2 = tl.arange(0, N)
-        x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
-        y = tl.load(y_ptr + offset2)
-        _, y_broadcasted = tl.broadcast(x, y)
-        tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)
-
-    M = 32
-    N = 64
-    rs = RandomState(17)
-    x = numpy_random((M, N), dtype_str=dtype, rs=rs)
-    y = numpy_random(N, dtype_str=dtype, rs=rs)
-    _, y_broadcasted_np = np.broadcast_arrays(x, y)
-
-    x_tri = to_triton(x, device='cuda', dst_type=dtype)
-    y_tri = to_triton(y, device='cuda', dst_type=dtype)
-    y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
-
-    broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
-    assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
-
-
 # ---------------
 # test broadcast
 # ---------------
@@ -621,6 +592,36 @@ def test_broadcast(dtype, device):
     broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
     assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()

+# ----------
+# test slice
+# ----------
+
+
+def test_slice(device):
+
+    @triton.jit
+    def slice_kernel(XBLOCK: tl.constexpr):
+        data = tl.arange(0, XBLOCK)
+        tl.static_assert(data.shape == [XBLOCK])
+
+        t = data[None, :]
+        tl.static_assert(t.shape == [1, XBLOCK])
+
+        t = data[None, :, None]
+        tl.static_assert(t.shape == [1, XBLOCK, 1])
+
+        scalar = tl.full([], 1, tl.int32)
+        tl.static_assert(scalar.shape == [])
+
+        t = scalar[None]
+        tl.static_assert(t.shape == [1])
+
+        t = scalar[None, None]
+        tl.static_assert(t.shape == [1, 1])
+
+    slice_kernel[(1,)](XBLOCK=32)
+
+
 # ------------------
 # test invalid slice
 # ------------------
@@ -669,6 +670,14 @@ def test_expand_dims(device):
         t = tl.expand_dims(offset1, (3, 1, 2))
         tl.static_assert(t.shape == [N, 1, 1, 1])

+        scalar = tl.sum(offset1)
+        tl.static_assert(scalar.shape == [])
+        t = tl.expand_dims(scalar, 0)
+        tl.static_assert(t.shape == [1])
+
+        t = tl.expand_dims(scalar, -1)
+        tl.static_assert(t.shape == [1])
+
     N = 32
     dummy_tensor = torch.empty((), device=device)
     expand_dims_kernel[(1,)](dummy_tensor, N)
@@ -689,6 +698,13 @@ def test_expand_dims_error_cases(device):
         t = tl.expand_dims(offset1, 1)
         t = tl.expand_dims(offset1, 2)

+    @triton.jit
+    def dim_out_of_range3(dummy, N: tl.constexpr):
+        offset1 = tl.arange(0, 1)
+        scalar = tl.sum(offset1)
+
+        t = tl.expand_dims(scalar, 1)
+
     @triton.jit
     def duplicate_dim1(dummy, N: tl.constexpr):
         offset1 = tl.arange(0, N)
@@ -710,6 +726,9 @@ def test_expand_dims_error_cases(device):
     with pytest.raises(triton.CompilationError, match="invalid axis 2"):
         dim_out_of_range2[(1,)](dummy_tensor, N)

+    with pytest.raises(triton.CompilationError, match="invalid axis 1"):
+        dim_out_of_range3[(1,)](dummy_tensor, N)
+
     with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
         duplicate_dim1[(1,)](dummy_tensor, N)

@@ -2467,7 +2486,8 @@ def test_dot_mulbroadcastred(in_dtype, device):


 @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
-def test_full(dtype_str, device):
+@pytest.mark.parametrize("shape", [(), (1,), (128,)])
+def test_full(dtype_str, shape, device):
     if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
         # PyTorch only has unsigned 8, but not 16, 32, or 64
         dtype = getattr(torch, dtype_str[1:])  # uintx -> intx
@@ -2478,21 +2498,28 @@ def test_full(dtype_str, device):
     @triton.jit
     def kernel_static(out):
         a = GENERATE_TEST_HERE
+        tl.static_assert(a.shape == SHAPE)
         out_ptr = out + tl.arange(0, 128)[:]
         tl.store(out_ptr, a)

     @triton.jit
     def kernel_dynamic(out, val, dtype: tl.constexpr):
-        a = tl.full((128,), val, dtype)
+        a = tl.full(SHAPE, val, dtype)
+        tl.static_assert(a.shape == SHAPE)
         out_ptr = out + tl.arange(0, 128)[:]
         tl.store(out_ptr, a)

-    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
+    kernel_static_patched = patch_kernel(kernel_static, {
+        'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})",
+        'SHAPE': str(list(shape)),
+    })
     out_static = torch.zeros((128), dtype=dtype, device=device)
     kernel_static_patched[(1,)](out_static)
-    out_dynamic = torch.zeros((128), dtype=dtype, device=device)
-    kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
     assert torch.all(out_static == 2)
+
+    kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))})
+    out_dynamic = torch.zeros((128), dtype=dtype, device=device)
+    kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
     assert torch.all(out_dynamic == 2)

@@ -531,9 +531,7 @@ class tensor:
         # IR handle
         self.handle = handle
         # Block shape
-        self.shape = (1, )
-        if type.is_block():
-            self.shape = type.shape
+        self.shape = type.shape if type.is_block() else ()
         self.numel = 1
         for s in self.shape:
             self.numel *= s
@@ -743,7 +741,7 @@ class tensor:

     @builtin
     def __getitem__(self, slices, _builder=None):
-        if isinstance(slices, slice):
+        if isinstance(slices, (slice, constexpr)):
            slices = [slices]
        ret = self
        for dim, sl in enumerate(slices):
@@ -501,25 +501,31 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
     if isinstance(value, tl.tensor):
         assert value.numel.value == 1, "only accepts size-1 tensor"
         value = cast(value, dtype, builder)
-        ret_ty = tl.block_type(value.dtype, shape)
-        return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
     else:
         # scalar
+        if dtype is None:
+            raise ValueError("dtype must be specified when value is not a tensor")
         if value == 0:
             value = builder.get_null_value(dtype.to_ir(builder))
         else:
             get_value_fn = getattr(builder, f"get_{dtype.name}")
             value = get_value_fn(value)
-        if dtype is None:
-            raise ValueError("dtype must be specified when value is not a tensor")
-        ret_ty = tl.block_type(dtype, shape)
-        return tl.tensor(builder.create_splat(value, shape), ret_ty)
+        value = tl.tensor(value, dtype)
+
+    return splat(value, shape, builder)


 # ===----------------------------------------------------------------------===//
 # Shape Manipulation
 # ===----------------------------------------------------------------------===//

+def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
+    assert not value.type.is_block(), "Cannot splat a block tensor"
+    if len(shape) == 0:
+        return value
+    ret_ty = tl.block_type(value.dtype, shape)
+    return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
+
+
 def view(input: tl.tensor,
          dst_shape: List[int],
@@ -544,8 +550,12 @@ def reshape(input: tl.tensor,


 def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    dst_shape = list(input.type.shape)
+    dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
     dst_shape.insert(axis, 1)
+
+    if not input.type.is_block():
+        return splat(input, shape=dst_shape, builder=builder)
+
     ret_ty = tl.block_type(input.type.scalar, dst_shape)
     return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)

@@ -1506,7 +1516,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:


 def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
-    if len(x.shape) != len(values):
+    if max(1, len(x.shape)) != len(values):
        raise ValueError("Shape of input to multiple_of does not match the length of values")
    x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
    return x