mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
workaround for num_programs() to get correct value (#271)
* workaround for num_programs to return correct value for issue #259.
This commit is contained in:
@@ -122,6 +122,34 @@ def test_empty_kernel(dtype_x, device='cuda'):
|
||||
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
|
||||
|
||||
|
||||
def test_program_functions(dtype_x='float16', device='cuda'):
    """Check ``tl.program_id`` / ``tl.num_programs`` on a (1, 1, 2) launch grid.

    Every program instance verifies its own ids and the grid extents it
    observes, then writes the boolean verdict into ``Y`` at its linearised
    program index. The host side expects both output slots to be ``True``.
    """
    SIZE = 128
    # os.environ["TRITON_DEBUG"] = "1"

    @triton.jit
    def kernel(X, Y, SIZE: tl.constexpr):
        # Grid extents as seen from inside the kernel.
        grid_x = tl.num_programs(0)
        grid_y = tl.num_programs(1)
        grid_z = tl.num_programs(2)

        # This instance's coordinates within the grid.
        idx_x = tl.program_id(0)
        idx_y = tl.program_id(1)
        idx_z = tl.program_id(2)

        # Flatten (x, y, z) into one output slot, x varying fastest.
        slot = idx_x + grid_x * (idx_y + grid_y * idx_z)

        # Conditions are folded in two operands at a time, mirroring the
        # launch grid (1, 1, 2): x and y ids are 0, z id is 0 or 1.
        ok = (idx_x == 0)
        ok = ok and (idx_y == 0)
        ok = ok and (idx_z == 0 or idx_z == 1)
        ok = ok and (grid_x == 1)
        ok = ok and (grid_y == 1)
        ok = ok and (grid_z == 2)
        tl.store(Y + slot, ok)

    check_type_supported(dtype_x)
    x = to_triton(
        numpy_random(SIZE, dtype_str=dtype_x),
        device=device,
        dst_type=dtype_x,
    )

    # One boolean result per program instance; `expected` is all-True.
    expected = torch.ones((2,), device=device, dtype=torch.bool)
    y = torch.zeros((2,), device=device, dtype=torch.bool)

    launch_grid = (1, 1, 2)
    kernel[launch_grid](x, y, SIZE=SIZE, num_warps=4)

    np.testing.assert_allclose(to_numpy(expected), to_numpy(y))
|
||||
|
||||
# generic test functions
|
||||
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
|
||||
check_type_supported(dtype_x) # early return if dtype_x is not supported
|
||||
|
||||
Reference in New Issue
Block a user