mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ROCM] Fix assert helper
This commit is contained in:
@@ -52,7 +52,7 @@ def test_assert(func: str):
|
||||
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "no_debug":
|
||||
# TRITON_DEBUG=True can override the debug flag
|
||||
kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_no_debug[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
@@ -116,11 +116,11 @@ def test_assert_nested(caller: str, callee: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if caller == "none":
|
||||
kernel_device_assert_nested[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
kernel_device_assert_nested[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "true":
|
||||
kernel_device_assert_nested_true[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
kernel_device_assert_nested_true[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
elif caller == "false":
|
||||
kernel_device_assert_nested_false[(1,)](x, y, BLOCK=shape[0], jit_debug=callee)
|
||||
kernel_device_assert_nested_false[(1,)](x, y, num_warps=2, BLOCK=shape[0], jit_debug=callee)
|
||||
assert_close(y, x)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user