[FRONTEND] Fix return op related control flow issues (#1637)

- Case 1: Return after static control flow is taken. Peel off
instructions after the first `return` for each basic block.

```python
if static_condition:
    tl.store(...)
    return
return
```

- Case 2: Return exists in both `if` and `else` branches of an inlined
`JITFunction` function

```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```

- Case 3: Return exists in a `JITFunction` from another module

```python
import module
if cond:
    a = module.func()
```

- Case 4: A chain of calls through undefined local variables

```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```

- Case 5: Call a function `func` without returning variables. `func` is
recognized as an `Expr` first instead of a `Call`.

```python
if cond:
    foo()
else:
    bar()
```

- Case 6: Call a `noinline` function. We don't need to check if the
function contains any return op.
This commit is contained in:
Keren Zhou
2023-05-09 12:51:14 -04:00
committed by GitHub
parent 319af1fb65
commit b19b274d93
3 changed files with 140 additions and 29 deletions

View File

@@ -2550,24 +2550,30 @@ def test_if_else():
assert to_numpy(out)[0] == false_val[0]
def test_if_return():
@pytest.mark.parametrize("mode", ["dynamic", "static"])
def test_if_return(mode):
@triton.jit
def kernel(ExitEarly, Out):
if tl.load(ExitEarly):
tl.store(Out, 0)
return
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
if mode == "dynamic":
if tl.load(ExitEarly):
tl.store(Out, 0)
return
else:
if cond:
tl.store(Out, 0)
return
tl.store(Out, 1)
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
# exit early path taken
exit_early[0] = 1
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, True, mode)
assert to_numpy(out)[0] == 0
# exit early path not taken
exit_early[0] = 0
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, False, mode)
assert to_numpy(out)[0] == 1
@@ -2576,7 +2582,34 @@ def add_fn(x):
return x + 1
@pytest.mark.parametrize("call_type", ["attribute", "jit_function"])
@triton.jit(noinline=True)
def add_fn_noinline(x):
return x + 1
@triton.jit
def add_fn_return(x, pid):
if pid == 0:
return x + 1
else:
return x + 2
@triton.jit
def add_fn_expr(Out, x):
tl.store(Out, x)
@triton.jit
def add_fn_static_cond(x, cond: tl.constexpr):
if cond == "":
return x
else:
return x + 1
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
@@ -2584,13 +2617,34 @@ def test_if_call(call_type):
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32)
a = a.to(tl.int32).to(tl.int32)
o = a
else:
a = o
a = add_fn(a)
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
o = a
tl.store(Out, o)
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')