[FRONTEND] Add support for scalar conditions in device_assert (#1641)

This sometimes happens in TorchInductor. See
https://github.com/pytorch/pytorch/pull/100880.
More generally, it's useful to be able to write `tl.device_assert(False,
msg)`.

Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit was authored by Mario Lezcano Casado on 2023-05-10 07:05:00 +01:00 and committed via GitHub.
parent 0cd8f05e01
commit 6b1af5fe37
2 changed files with 13 additions and 0 deletions

View File

@@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
    """Copy BLOCK elements from X to Y, with a device_assert on a scalar condition.

    Exercises passing a plain Python scalar (not a tensor) as the condition
    to tl.device_assert — the trivially-true `0 == 0` should never fire.
    """
    vals = tl.load(X + tl.arange(0, BLOCK))
    # Scalar (non-tensor) condition — trivially true, so the assert is a no-op.
    tl.device_assert(0 == 0, "x != 0")
    tl.store(Y + tl.arange(0, BLOCK), vals)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -34,6 +42,7 @@ def test_assert(func: str):
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, BLOCK=shape[0])
elif func == "static_assert":