mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Fix return op related control flow issues (#1637)
- Case 1: Return after static control flow is taken. Peel off
instructions after the first `return` for each basic block.
```python
if static_condition:
    tl.store(...)
    return
return
```
- Case 2: Return exists in both the `if` and `else` branches of an inlined
`JITFunction`
```python
def foo():
    if dynamic_condition:
        return a
    else:
        return b
```
- Case 3: Return exists in a `JITFunction` from another module
```python
import module
if cond:
    a = module.func()
```
- Case 4: A chain of calls through undefined local variables
```python
import module
if cond:
    a = x
    a = a.to(tl.int32).to(tl.int32)
```
- Case 5: Call a function `func` as a bare statement, without assigning a
return value. Such a call is recognized as an `Expr` first instead of a `Call`.
```python
if cond:
    foo()
else:
    bar()
```
- Case 6: Call a `noinline` function. We don't need to check if the
function contains any return op.
This commit is contained in:
@@ -2550,24 +2550,30 @@ def test_if_else():
|
||||
assert to_numpy(out)[0] == false_val[0]
|
||||
|
||||
|
||||
def test_if_return():
|
||||
@pytest.mark.parametrize("mode", ["dynamic", "static"])
|
||||
def test_if_return(mode):
|
||||
|
||||
@triton.jit
|
||||
def kernel(ExitEarly, Out):
|
||||
if tl.load(ExitEarly):
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
|
||||
if mode == "dynamic":
|
||||
if tl.load(ExitEarly):
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
else:
|
||||
if cond:
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
tl.store(Out, 1)
|
||||
|
||||
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
# exit early path taken
|
||||
exit_early[0] = 1
|
||||
kernel[(1,)](exit_early, out)
|
||||
kernel[(1,)](exit_early, out, True, mode)
|
||||
assert to_numpy(out)[0] == 0
|
||||
# exit early path not taken
|
||||
exit_early[0] = 0
|
||||
kernel[(1,)](exit_early, out)
|
||||
kernel[(1,)](exit_early, out, False, mode)
|
||||
assert to_numpy(out)[0] == 1
|
||||
|
||||
|
||||
@@ -2576,7 +2582,34 @@ def add_fn(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "jit_function"])
|
||||
@triton.jit(noinline=True)
|
||||
def add_fn_noinline(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_return(x, pid):
|
||||
if pid == 0:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_expr(Out, x):
|
||||
tl.store(Out, x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_static_cond(x, cond: tl.constexpr):
|
||||
if cond == "":
|
||||
return x
|
||||
else:
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
|
||||
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
|
||||
def test_if_call(call_type):
|
||||
@triton.jit
|
||||
def kernel(Out, call_type: tl.constexpr):
|
||||
@@ -2584,13 +2617,34 @@ def test_if_call(call_type):
|
||||
o = tl.load(Out)
|
||||
if pid == 0:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
a = o + 1
|
||||
a = a.to(tl.int32)
|
||||
a = a.to(tl.int32).to(tl.int32)
|
||||
o = a
|
||||
else:
|
||||
a = o
|
||||
a = add_fn(a)
|
||||
if call_type == "jit_function":
|
||||
# regular function call
|
||||
a = add_fn(a)
|
||||
elif call_type == "jit_function_return":
|
||||
# function without end_if block
|
||||
a = add_fn_return(a, pid)
|
||||
elif call_type == "ifexp":
|
||||
# ifexp expression
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
elif call_type == "expr":
|
||||
if pid == 1:
|
||||
return
|
||||
a = add_fn(a)
|
||||
if pid == 0:
|
||||
# call without return
|
||||
add_fn_expr(Out, a)
|
||||
elif call_type == "jit_function_static_cond":
|
||||
a = add_fn_static_cond(a, call_type)
|
||||
elif call_type == "jit_function_noinline":
|
||||
a = add_fn_noinline(a)
|
||||
o = a
|
||||
|
||||
tl.store(Out, o)
|
||||
|
||||
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
|
||||
Reference in New Issue
Block a user