diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py index f4511500f3..be8027238b 100644 --- a/tinygrad/codegen/linearize.py +++ b/tinygrad/codegen/linearize.py @@ -112,6 +112,16 @@ def block_merge(ctx, x:UOp): pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),]) +def block_finalize(block:UOp): + if len(block.src) == 0: return None + _uops = sorted(dedup(block.src), key=lambda x: x.tuplize) + assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops) + _uops += block.arg.lst + assert _uops[-1].op is Ops.SINK, "block doesn't end with SINK" + return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1]))) + +pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)]) + # NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed def block_reorder(in_block:UOp): in_this_block = set(in_block.arg.lst) @@ -212,14 +222,11 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]: # final rewrite to merge all blocks into one sink = graph_rewrite(sink, pm_block_merge, ctx=children) - # there should just be one block left, with a few parents with 0 srcs - assert sink.op is Ops.BLOCK - _uops = sorted(dedup(sink.src), key=lambda x: x.tuplize) - assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops) - _uops += sink.arg.lst + # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter) + sink = graph_rewrite(sink, pm_block_finalize) # sanity checks (NOTE: these can cause things to be skipped in BEAM) - if not skip_check: type_verify(_uops) + if not skip_check: type_verify(sink.arg.lst) - # strip the SINK - return _uops[:-1] + # return the list. TODO: refactor to return the UOp + return list(sink.arg.lst)