[FRONTEND] Differentiate between bool and int in the frontend (#1678)

`bool` is a subclass of `int`, so `isinstance(bool_var, int) == True`,
and a `bool` constant will be converted to an `int` constant.

In Triton specifically, if a bool var is treated as an integer, it
prevents us from using the `logical_and` operator, which requires both
operands to have the same bit length.

> Cannot bitcast data-type of size 32 to data-type of size 1

By differentiating int and bool, we can make the syntax closer to
native Python. We can now use `if bool_var and condition` to
check the truthiness, and `if bool_var is True` to check identity.
This commit is contained in:
Keren Zhou
2023-05-16 14:24:16 -04:00
committed by GitHub
parent 177b46b9ef
commit 3baab48eaf
3 changed files with 13 additions and 7 deletions

View File

@@ -2381,26 +2381,32 @@ def test_call(type):
# -------------
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
def test_if(if_type):
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
pid = tl.program_id(0)
cond = tl.load(Cond)
if IfType == "if":
if pid % 2:
if pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
else:
elif IfType == "if_exp":
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and":
if BoolVar and pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret, if_type)
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
assert torch.equal(ret, x_true)
def test_num_warps_pow2():