[FRONTEND] Implement ternary operator for dynamic values (#2560)

This commit is contained in:
Vedant Roy
2023-11-01 17:21:32 -07:00
committed by GitHub
parent e7fdfd76fb
commit 702cde0d6f
2 changed files with 43 additions and 7 deletions

View File

@@ -3169,12 +3169,9 @@ def test_call(type, num_ctas, device):
# test if
# -------------
# TODO(Keren): if_exp_dynamic
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", "if_and_static"])
def test_if(if_type, device):
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
pid = tl.program_id(0)
@@ -3185,6 +3182,12 @@ def test_if(if_type, device):
elif 1 == pid % 2: # req
tl.store(Ret, tl.load(XFalse))
elif IfType == "if_exp_dynamic":
val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse)
tl.store(Ret, val)
elif IfType == "if_exp_dynamic_constexpr":
val = 3.14 if pid % 2 == 0 else tl.load(XFalse)
tl.store(Ret, val)
elif IfType == "if_exp_dynamic_void":
tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_exp_static":
tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse))

View File

@@ -624,9 +624,42 @@ class CodeGenerator(ast.NodeVisitor):
def visit_IfExp(self, node):
cond = self.visit(node.test)
if _is_triton_tensor(cond):
raise UnsupportedLanguageConstruct(
None, node,
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
cond = cond.to(language.int1, _builder=self.builder)
# TODO: Deal w/ more complicated return types (e.g tuple)
with enter_sub_region(self):
ip, last_loc = self._get_insertion_point_and_loc()
then_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(then_block)
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
then_block = self.builder.get_insertion_block()
else_block = self.builder.create_block()
self.builder.set_insertion_point_to_start(else_block)
# do not need to reset lscope since
# ternary expressions cannot define new variables
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
else_block = self.builder.get_insertion_block()
self._set_insertion_point_and_loc(ip, last_loc)
assert then_val.type == else_val.type, \
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
ret_type = then_val.type
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
then_block.merge_block_before(if_op.get_then_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_then_block())
self.builder.create_yield_op([then_val.handle])
self.builder.set_insertion_point_to_end(if_op.get_then_block())
else_block.merge_block_before(if_op.get_else_block())
if ret_type_ir:
self.builder.set_insertion_point_to_end(if_op.get_else_block())
self.builder.create_yield_op([else_val.handle])
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
else:
cond = _unwrap_if_constexpr(cond)
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks