mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Add support for scalar conditions in device_assert (#1641)
This sometimes happens in TorchInductor. See https://github.com/pytorch/pytorch/pull/100880. More generally, it's useful to be able to write `tl.device_assert(False, msg)`. Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
committed by
GitHub
parent
0cd8f05e01
commit
6b1af5fe37
@@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
# Trivial assert
|
||||
tl.device_assert(0 == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
@@ -34,6 +42,7 @@ def test_assert(func: str):
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
|
||||
Reference in New Issue
Block a user