workaround for num_programs() to get correct value (#271)

* workaround for num_programs to return correct value for issue #259.
This commit is contained in:
Shucai Xiao
2023-07-25 15:46:19 -05:00
committed by GitHub
parent 91c66f4a7f
commit 1cccf14f62
2 changed files with 44 additions and 0 deletions

View File

@@ -122,6 +122,34 @@ def test_empty_kernel(dtype_x, device='cuda'):
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
def test_program_functions(dtype_x = 'float16', device='cuda'):
SIZE = 128
# os.environ["TRITON_DEBUG"] = "1"
@triton.jit
def kernel(X, Y, SIZE: tl.constexpr):
pid_x = tl.program_id(0)
pid_y = tl.program_id(1)
pid_z = tl.program_id(2)
num_pids_x = tl.num_programs(0)
num_pids_y = tl.num_programs(1)
num_pids_z = tl.num_programs(2)
pid = pid_z * (num_pids_x * num_pids_y) + pid_y * num_pids_x + pid_x
val = (pid_x == 0)
val = val and (pid_y == 0)
val = val and (pid_z == 0 or pid_z == 1)
val = val and (num_pids_x == 1)
val = val and (num_pids_y == 1)
val = val and (num_pids_z == 2)
tl.store(Y + pid, val)
check_type_supported(dtype_x)
x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x)
gold = torch.ones((2,), device = device, dtype = torch.bool)
y = torch.zeros((2,), device = device, dtype = torch.bool)
kernel[(1, 1, 2)](x, y, SIZE=SIZE, num_warps=4)
np.testing.assert_allclose(to_numpy(gold), to_numpy(y))
# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda'):
check_type_supported(dtype_x) # early return if dtype_x is not supported