mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Implement ternary operator for dynamic values (#2560)
This commit is contained in:
@@ -3169,12 +3169,9 @@ def test_call(type, num_ctas, device):
|
||||
# test if
|
||||
# -------------
|
||||
|
||||
# TODO(Keren): if_exp_dynamic
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", "if_and_static"])
|
||||
def test_if(if_type, device):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
@@ -3185,6 +3182,12 @@ def test_if(if_type, device):
|
||||
elif 1 == pid % 2: # req
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_dynamic":
|
||||
val = tl.load(XTrue) if pid % 2 == 0 else tl.load(XFalse)
|
||||
tl.store(Ret, val)
|
||||
elif IfType == "if_exp_dynamic_constexpr":
|
||||
val = 3.14 if pid % 2 == 0 else tl.load(XFalse)
|
||||
tl.store(Ret, val)
|
||||
elif IfType == "if_exp_dynamic_void":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_exp_static":
|
||||
tl.store(Ret, tl.load(XTrue)) if BoolVar else tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
@@ -624,9 +624,42 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_IfExp(self, node):
|
||||
cond = self.visit(node.test)
|
||||
if _is_triton_tensor(cond):
|
||||
raise UnsupportedLanguageConstruct(
|
||||
None, node,
|
||||
"Triton does not support `if` expressions (ternary operators) with dynamic conditions, use `if` statements instead")
|
||||
cond = cond.to(language.int1, _builder=self.builder)
|
||||
# TODO: Deal w/ more complicated return types (e.g tuple)
|
||||
with enter_sub_region(self):
|
||||
ip, last_loc = self._get_insertion_point_and_loc()
|
||||
|
||||
then_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(then_block)
|
||||
then_val = language.core._to_tensor(self.visit(node.body), self.builder)
|
||||
then_block = self.builder.get_insertion_block()
|
||||
|
||||
else_block = self.builder.create_block()
|
||||
self.builder.set_insertion_point_to_start(else_block)
|
||||
# do not need to reset lscope since
|
||||
# ternary expressions cannot define new variables
|
||||
else_val = language.core._to_tensor(self.visit(node.orelse), self.builder)
|
||||
else_block = self.builder.get_insertion_block()
|
||||
|
||||
self._set_insertion_point_and_loc(ip, last_loc)
|
||||
|
||||
assert then_val.type == else_val.type, \
|
||||
f'ternary expression with dynamic condition has inconsistent types {then_val.type} and {else_val.type}'
|
||||
ret_type = then_val.type
|
||||
|
||||
ret_type_ir = [ret_type.to_ir(self.builder)] if ret_type != language.void else []
|
||||
if_op = self.builder.create_if_op(ret_type_ir, cond.handle, True)
|
||||
then_block.merge_block_before(if_op.get_then_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
self.builder.create_yield_op([then_val.handle])
|
||||
|
||||
self.builder.set_insertion_point_to_end(if_op.get_then_block())
|
||||
else_block.merge_block_before(if_op.get_else_block())
|
||||
if ret_type_ir:
|
||||
self.builder.set_insertion_point_to_end(if_op.get_else_block())
|
||||
self.builder.create_yield_op([else_val.handle])
|
||||
return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None
|
||||
else:
|
||||
cond = _unwrap_if_constexpr(cond)
|
||||
if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks
|
||||
|
||||
Reference in New Issue
Block a user