mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Fix test_if
This commit is contained in:
committed by
Jason Furmanek
parent
8ccc4b0cce
commit
28c571ea43
@@ -2142,27 +2142,43 @@ def test_call():
|
||||
# test if
|
||||
# -------------
|
||||
|
||||
# TODO(Keren): if_exp_dynamic
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
|
||||
def test_if(if_type):
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
|
||||
def test_if(if_type, device):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
cond = tl.load(Cond)
|
||||
if IfType == "if":
|
||||
if pid % 2:
|
||||
if pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_dynamic":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_static":
|
||||
tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and_dynamic":
|
||||
if BoolVar and pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and_static":
|
||||
if StaticVaue != 0 and StaticVaue != 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type)
|
||||
cond = torch.ones(1, dtype=torch.int32, device=device)
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device=device)
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device=device)
|
||||
ret = torch.zeros(1, dtype=torch.float32, device=device)
|
||||
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
|
||||
assert torch.equal(ret, x_true)
|
||||
|
||||
|
||||
def test_num_warps_pow2():
|
||||
|
||||
Reference in New Issue
Block a user