ROCM IFU: Fix test_if

This commit is contained in:
Michael Melesse
2023-09-28 15:03:58 -05:00
committed by Jason Furmanek
parent 8ccc4b0cce
commit 28c571ea43

View File

@@ -2142,27 +2142,43 @@ def test_call():
# test if
# -------------
# TODO(Keren): if_exp_dynamic
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
def test_if(if_type):
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
def test_if(if_type, device):
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
pid = tl.program_id(0)
cond = tl.load(Cond)
if IfType == "if":
if pid % 2:
if pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
elif IfType == "if_exp_dynamic":
tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_exp_static":
tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and_dynamic":
if BoolVar and pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and_static":
if StaticVaue != 0 and StaticVaue != 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
else:
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret, if_type)
cond = torch.ones(1, dtype=torch.int32, device=device)
x_true = torch.tensor([3.14], dtype=torch.float32, device=device)
x_false = torch.tensor([1.51], dtype=torch.float32, device=device)
ret = torch.zeros(1, dtype=torch.float32, device=device)
kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
assert torch.equal(ret, x_true)
def test_num_warps_pow2():