mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Differentiate between bool and int in the frontend (#1678)
`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`, and a `bool` constant will be converted to an `int` constant. In triton specifically, if a bool var is treated as an integer, it prevents us using the `logical_and` operator which requires both operands have the same bit length. > Cannot bitcast data-type of size 32 to data-type of size 1 By differentiating int and bool, it allows us to make the syntax more close to native python. We can now use `if bool_var and condition` to check the truthiness, and `if bool_var is True` to check identity.
This commit is contained in:
@@ -2381,26 +2381,32 @@ def test_call(type):
|
||||
# -------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
|
||||
def test_if(if_type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
cond = tl.load(Cond)
|
||||
if IfType == "if":
|
||||
if pid % 2:
|
||||
if pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
else:
|
||||
elif IfType == "if_exp":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and":
|
||||
if BoolVar and pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type)
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
|
||||
assert torch.equal(ret, x_true)
|
||||
|
||||
|
||||
def test_num_warps_pow2():
|
||||
|
||||
Reference in New Issue
Block a user