mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Disable IfExp on dynamic conditions (#2100)
`if _unwrap_if_constexpr(cond)` then enters `node.body` is wrong when cond is a tensor since we cannot statically evaluate a dynamic tensor's value. The right way to solve the problem is probably: 1. visit the ast of IfExp (do not build IRs) 2. get the type of the last statement 3. initialize the return value and assign it to livein 4. call visit_If
This commit is contained in:
@@ -2899,8 +2899,8 @@ def test_call(type, num_ctas, device):
|
||||
# test if
|
||||
# -------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and_dynamic", "if_and_static"])
|
||||
# TODO(Keren): if_exp_dynamic
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
|
||||
def test_if(if_type, device):
|
||||
|
||||
@triton.jit
|
||||
@@ -2912,8 +2912,10 @@ def test_if(if_type, device):
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_dynamic":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_static":
|
||||
tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and_dynamic":
|
||||
if BoolVar and pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
@@ -2928,7 +2930,7 @@ def test_if(if_type, device):
|
||||
cond = torch.ones(1, dtype=torch.int32, device=device)
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device=device)
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device=device)
|
||||
ret = torch.empty(1, dtype=torch.float32, device=device)
|
||||
ret = torch.zeros(1, dtype=torch.float32, device=device)
|
||||
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1)
|
||||
assert torch.equal(ret, x_true)
|
||||
@@ -3209,8 +3211,9 @@ def add_fn_static_cond(x, cond: tl.constexpr):
|
||||
return x + 1
|
||||
|
||||
|
||||
# TODO(Keren): if_exp
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit",
|
||||
"jit", "jit_if", "jit_ifexp", "jit_expr",
|
||||
"jit", "jit_if", "jit_expr",
|
||||
"jit_static_cond", "jit_noinline", "jit_extern"])
|
||||
def test_if_call(call_type, device):
|
||||
@triton.jit
|
||||
@@ -3241,7 +3244,7 @@ def test_if_call(call_type, device):
|
||||
a = o
|
||||
a = add_fn_return(a, pid)
|
||||
o = a
|
||||
elif call_type == "jit_ifexp":
|
||||
elif call_type == "jit_if_exp":
|
||||
# ifexp expression
|
||||
if pid == 0:
|
||||
a = o
|
||||
|
||||
@@ -64,6 +64,10 @@ def _is_triton_scalar(o: Any) -> bool:
|
||||
return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
|
||||
|
||||
|
||||
def _is_list_like(o: Any) -> bool:
|
||||
return isinstance(o, (list, tuple))
|
||||
|
||||
|
||||
def _unwrap_if_constexpr(o: Any):
|
||||
return o.value if isinstance(o, constexpr) else o
|
||||
|
||||
@@ -284,6 +288,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
# AST visitor
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
|
||||
# Ensure that stmts is iterable
|
||||
if not _is_list_like(stmts):
|
||||
stmts = [stmts]
|
||||
for stmt in stmts:
|
||||
ret_type = self.visit(stmt)
|
||||
if ret_type is not None and isinstance(stmt, ast.Return):
|
||||
@@ -413,9 +420,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.")
|
||||
names = _names[0]
|
||||
values = self.visit(node.value)
|
||||
if not isinstance(names, tuple):
|
||||
if not _is_list_like(names):
|
||||
names = [names]
|
||||
if not isinstance(values, tuple):
|
||||
if not _is_list_like(values):
|
||||
values = [values]
|
||||
native_nontensor_types = (language.dtype, )
|
||||
for name, value in zip(names, values):
|
||||
@@ -619,11 +626,19 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
if _unwrap_if_constexpr(cond):
|
||||
return self.visit(node.body)
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
|
||||
else:
|
||||
return self.visit(node.orelse)
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format(
|
||||
', '.join(_.__name__ for _ in _condition_types), type(cond).__name__))
|
||||
if cond:
|
||||
return self.visit(node.body)
|
||||
else:
|
||||
return self.visit(node.orelse)
|
||||
|
||||
def visit_Pass(self, node):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user