[FRONTEND] fix AST IR generation for while loop nested inside other SCF (#1947)

Visiting the body of the while loop twice did not properly restore
the insertion point, and leaked the dummy block.
This commit is contained in:
Mehdi Amini
2023-07-15 10:17:29 -07:00
committed by GitHub
parent 8207eabd7b
commit 51fc42a568
2 changed files with 18 additions and 1 deletions

View File

@@ -3082,6 +3082,21 @@ def test_while(device):
assert out_i[0] == init_i[0] + 1
assert out_j[0] == bound[0]
def test_nested_while(device):
    """Check IR generation for a ``while`` loop nested inside a ``for`` loop.

    Named ``test_nested_while`` rather than ``test_while`` because the module
    already defines ``test_while``; reusing that name would rebind it and
    pytest would collect only one of the two tests.

    Args:
        device: device on which the input tensors are allocated and the
            kernel is launched.
    """

    @triton.jit
    def nested_while(data, countPtr):
        for i in range(10):
            count = tl.load(countPtr)
            while count > 0:
                tl.store(data, tl.load(data) + 1.0)
                count = count - 2

    counter = torch.tensor([8], dtype=torch.int32, device=device)
    data = torch.zeros((1,), device=device, dtype=torch.float32)
    nested_while[(1,)](data, counter)
    # Each of the 10 outer iterations reloads count = 8 and runs the inner
    # loop 4 times (count = 8, 6, 4, 2), adding 1.0 per pass: 10 * 4 = 40.
    assert data[0] == 40
# def test_for_if(device):
# @triton.jit

View File

@@ -659,6 +659,7 @@ class CodeGenerator(ast.NodeVisitor):
def visit_While(self, node):
with enter_sub_region(self) as sr:
liveins, insert_block = sr
ip, last_loc = self._get_insertion_point_and_loc()
# loop body (the after region)
# loop_block = self.builder.create_block()
@@ -668,6 +669,7 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_compound_statement(node.body)
self.scf_stack.pop()
loop_defs = self.local_defs
dummy.erase()
# collect loop-carried values
names = []
@@ -684,7 +686,7 @@ class CodeGenerator(ast.NodeVisitor):
ret_types.append(loop_defs[name].type)
init_args.append(liveins[name])
self.builder.set_insertion_point_to_end(insert_block)
self._set_insertion_point_and_loc(ip, last_loc)
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
[arg.handle for arg in init_args])
# merge the condition region