mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
minor linearizer refactor to finalize in rewrite [pr] (#9148)
This commit is contained in:
@@ -112,6 +112,16 @@ def block_merge(ctx, x:UOp):
|
||||
|
||||
pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
|
||||
|
||||
def block_finalize(block:UOp):
|
||||
if len(block.src) == 0: return None
|
||||
_uops = sorted(dedup(block.src), key=lambda x: x.tuplize)
|
||||
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
|
||||
_uops += block.arg.lst
|
||||
assert _uops[-1].op is Ops.SINK, "block doesn't end with SINK"
|
||||
return UOp(Ops.BLOCK, arg=BasicBlock((), tuple(_uops[:-1])))
|
||||
|
||||
pm_block_finalize = PatternMatcher([(UPat(Ops.BLOCK, name="block"), block_finalize)])
|
||||
|
||||
# NOTE: any toposort should be valid here, unlike last time this isn't required, it's just for speed
|
||||
def block_reorder(in_block:UOp):
|
||||
in_this_block = set(in_block.arg.lst)
|
||||
@@ -212,14 +222,11 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
|
||||
# final rewrite to merge all blocks into one
|
||||
sink = graph_rewrite(sink, pm_block_merge, ctx=children)
|
||||
|
||||
# there should just be one block left, with a few parents with 0 srcs
|
||||
assert sink.op is Ops.BLOCK
|
||||
_uops = sorted(dedup(sink.src), key=lambda x: x.tuplize)
|
||||
assert all(len(x.src) == 0 and x.op not in {Ops.BLOCK, Ops.BLOCKSTART, Ops.BLOCKEND, Ops.BLOCKFORK} for x in _uops)
|
||||
_uops += sink.arg.lst
|
||||
# there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
|
||||
sink = graph_rewrite(sink, pm_block_finalize)
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check: type_verify(_uops)
|
||||
if not skip_check: type_verify(sink.arg.lst)
|
||||
|
||||
# strip the SINK
|
||||
return _uops[:-1]
|
||||
# return the list. TODO: refactor to return the UOp
|
||||
return list(sink.arg.lst)
|
||||
|
||||
Reference in New Issue
Block a user