From 9115ce88600636e128b90abb6e1edd09f3988f9f Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 26 Mar 2025 18:28:15 +0800
Subject: [PATCH] linearizer fixups from DSP branch (#9581)

---
 test/external/external_test_example.py |  2 +-
 tinygrad/codegen/linearize.py          | 14 ++++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/test/external/external_test_example.py b/test/external/external_test_example.py
index 1dc0b7b547..9c170819d4 100644
--- a/test/external/external_test_example.py
+++ b/test/external/external_test_example.py
@@ -7,7 +7,7 @@ def multidevice_test(fxn):
   exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
   def ret(self):
     for device in Device._devices:
-      if device in ["DISK", "NPY", "FAKE", "DSP"]: continue
+      if device in ["DISK", "NPY", "FAKE", "DSP", "NULL"]: continue
       if not CI: print(device)
       if device in exclude_devices:
         if not CI: print(f"WARNING: {device} test is excluded")
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index 5b3d3a42b5..413b391b2e 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -112,7 +112,8 @@ def block_merge(ctx, x:UOp):
       # keep it in srcs
       new_srcs.append(u)
   if len(to_append) == 0 and len(placed) == 0: return None
-  return UOp(x.op, dtypes.void, tuple(new_srcs), BasicBlock(tuple(sorted(new_ctx, key=lambda x: x.tuplize)), tuple(to_append)+x.arg.lst, x.arg.end))
+  return UOp(x.op, dtypes.void, tuple(new_srcs),
+             BasicBlock(tuple(dedup(sorted(new_ctx, key=lambda x: x.tuplize))), tuple(to_append)+x.arg.lst, x.arg.end))
 
 pm_block_merge = PatternMatcher([(UPat((Ops.BLOCKEND, Ops.BLOCK), name="x"), block_merge),])
 
@@ -165,6 +166,11 @@ def block_reorder(in_block:UOp):
   assert len(newlst) == len(in_block.arg.lst), f"len mismatch {len(newlst)} != {len(in_block.arg.lst)}"
   return in_block.replace(arg=BasicBlock(in_block.arg.ctx, tuple(newlst)))
 
+def upsettingly_promote_blockend(be:UOp):
+  new_srcs = tuple(b.replace(arg=BasicBlock(be.arg.ctx, b.arg.lst)) if b.op is Ops.BLOCK else b for b in be.src)
+  return be.replace(src=new_srcs) if be.src != new_srcs else None
+pm_force_upcast_block = PatternMatcher([(UPat(Ops.BLOCKEND, name="be"), upsettingly_promote_blockend)])
+
 def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
   assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
 
@@ -211,7 +217,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
     if not len(forks): break
     sink = sink.substitute(forks)
 
-  # combine matching BLOCKENDS
+  # combine matching BLOCKENDS, the keys of this dictionary are the RANGE UOps, values are the BLOCKENDs
   blockends_to_arg: dict[UOp, list[UOp]] = {}
   for be in sink.toposort:
     if be.op is Ops.BLOCKEND: blockends_to_arg.setdefault(be.arg.end, []).append(be)
@@ -230,6 +236,10 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> list[UOp]:
   # final rewrite to merge all blocks into one
   sink = graph_rewrite(sink, pm_block_merge, ctx=children)
 
+  # if there's BLOCKENDs left in the graph, we might have to merge. TODO: is there a better way to handle this?
+  while (newsink := graph_rewrite(sink, pm_force_upcast_block)) is not sink:
+    sink = graph_rewrite(newsink, pm_block_merge, ctx=children, name="bad_merge")
+
   # there should just be one block left, with a few parents with 0 srcs (now done in a rewriter)
   sink = graph_rewrite(sink, pm_block_finalize)
 
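
Not part of the patch, only an illustrative note: the new `while (newsink := graph_rewrite(sink, pm_force_upcast_block)) is not sink:` loop relies on a fixpoint idiom — since UOps are deduplicated, an unchanged graph appears to come back as the identical object, so an identity check detects when no pattern fired and `pm_block_merge` is re-run only while the force-upcast pass is still making progress. The sketch below is a self-contained, hypothetical stand-in (plain tuples instead of UOps, `promote_blocks` instead of `graph_rewrite`) showing just that loop shape; it is not tinygrad's API.

```python
# Toy "rewrite" pass: flattens one level of nesting per call and, like a rewriter
# with no matching rules, returns the *same object* when nothing changed.
def promote_blocks(node: tuple) -> tuple:
  flat, changed = [], False
  for item in node:
    if isinstance(item, tuple):
      flat.extend(item)   # "promote" the nested block's contents
      changed = True
    else:
      flat.append(item)
  return tuple(flat) if changed else node

sink = (1, (2, (3, 4)), 5)
# mirrors: while (newsink := graph_rewrite(sink, pm_force_upcast_block)) is not sink: ...
while (new_sink := promote_blocks(sink)) is not sink:
  sink = new_sink       # in the patch, a pm_block_merge pass runs here each iteration
print(sink)             # (1, 2, 3, 4, 5) -- loop stopped once the pass made no change
```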