mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] fix AST IR generation for while loop nested inside other SCF (#1947)
The process of visiting twice the body of the while didn't restore properly the insertion point, and was leaking the dummy block.
This commit is contained in:
@@ -3082,6 +3082,21 @@ def test_while(device):
|
||||
assert out_i[0] == init_i[0] + 1
|
||||
assert out_j[0] == bound[0]
|
||||
|
||||
def test_while(device):
|
||||
@triton.jit
|
||||
def nested_while(data, countPtr):
|
||||
for i in range(10):
|
||||
count = tl.load(countPtr)
|
||||
while count > 0:
|
||||
tl.store(data, tl.load(data) + 1.0)
|
||||
count = count - 2
|
||||
|
||||
counter = torch.tensor([8], dtype=torch.int32, device=device)
|
||||
data = torch.zeros((1,), device=device, dtype=torch.float32)
|
||||
nested_while[(1,)](data, counter)
|
||||
assert data[0] == 40
|
||||
|
||||
|
||||
# def test_for_if(device):
|
||||
|
||||
# @triton.jit
|
||||
|
||||
@@ -659,6 +659,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def visit_While(self, node):
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, insert_block = sr
|
||||
ip, last_loc = self._get_insertion_point_and_loc()
|
||||
|
||||
# loop body (the after region)
|
||||
# loop_block = self.builder.create_block()
|
||||
@@ -668,6 +669,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_compound_statement(node.body)
|
||||
self.scf_stack.pop()
|
||||
loop_defs = self.local_defs
|
||||
dummy.erase()
|
||||
|
||||
# collect loop-carried values
|
||||
names = []
|
||||
@@ -684,7 +686,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ret_types.append(loop_defs[name].type)
|
||||
init_args.append(liveins[name])
|
||||
|
||||
self.builder.set_insertion_point_to_end(insert_block)
|
||||
self._set_insertion_point_and_loc(ip, last_loc)
|
||||
while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types],
|
||||
[arg.handle for arg in init_args])
|
||||
# merge the condition region
|
||||
|
||||
Reference in New Issue
Block a user