From 5954a0975fefc690be04c15ac7a0be3f32189075 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 21 Aug 2025 15:15:54 -0700 Subject: [PATCH] fix some assigns on rangeify (#11774) * fix some assigns * llvm test * more tests * upd test --- .github/workflows/test.yml | 13 ++++++-- .../external_test_specific_conv.py} | 0 tinygrad/schedule/rangeify.py | 31 ++++++++++++------- 3 files changed, 30 insertions(+), 14 deletions(-) rename test/{test_specific_conv.py => speed/external_test_specific_conv.py} (100%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3a29f56859..107c9d4480 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -601,13 +601,22 @@ jobs: - name: Setup Environment uses: ./.github/actions/setup-tinygrad with: - key: rangeify-minimal + key: rangeify-minimal-llvm deps: testing_minimal + llvm: "true" - name: Test CPU=1 RANGEIFY=1 # TODO: add more passing tests here - run: CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20 + # test_symbolic_arange_sym_step is passing now + # test_threefry_doesnt_use_long is because there's a contig after the long now + run: | + CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \ + -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \ + test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \ + test/test_outerworld_range.py test/test_sample.py test/test_randomness.py test/test_tensor_data.py - name: Test CPU=1 RANGEIFY=1 PARTIAL_CONTIG=1 run: PARTIAL_CONTIG=1 CPU=1 RANGEIFY=1 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20 + - name: Test LLVM=1 RANGEIFY=1 (slow tests) + run: LLVM=1 RANGEIFY=1 python3 -m pytest -n auto test/models/test_mnist.py --durations 20 testdevectorize: name: Linux (devectorize) diff --git a/test/test_specific_conv.py b/test/speed/external_test_specific_conv.py similarity index 100% rename from test/test_specific_conv.py rename to test/speed/external_test_specific_conv.py diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index f80efd03b6..13408c3043 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -10,7 +10,12 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, K # 0. do some cleanup rewrites, mostly copied from the old stuff -earliest_rewrites = PatternMatcher([ +double_reshape = PatternMatcher([ + # RESHAPE on RESHAPE is the second reshape + (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))), +]) + +earliest_rewrites = double_reshape+PatternMatcher([ # UOp with size 0 is zero (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None), # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE @@ -18,8 +23,6 @@ earliest_rewrites = PatternMatcher([ # reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), - # RESHAPE on RESHAPE is the second reshape - (UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"), lambda x: x.replace(src=(x.src[0].src[0],))), # non shape changing RESHAPE is NOOP (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None), # RESHAPE after COPY @@ -49,6 +52,9 @@ def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None: for s in rb.src: if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None +def realize_assign(ctx:dict[UOp, None], a:UOp) -> None: + if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None + do_realize = PatternMatcher([ # always realize SINK parents (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)), @@ -56,6 +62,8 @@ do_realize = PatternMatcher([ (UPat({Ops.ASSIGN, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize), # realize parents of COPY, MSELECT, MSTACK (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents), + # realize input to assign (might be optimized out) + (UPat(Ops.ASSIGN, name="a"), realize_assign), ]) add_contiguous = PatternMatcher([ @@ -153,18 +161,18 @@ def map_expand(r:UOp, idx:UOp): pm_mops = PatternMatcher([ # this is like the definitions of these - (UPat(Ops.INDEX, src=(UPat(Ops.SHRINK, name="r"),), allow_any_len=True, name="idx"), + (UPat(Ops.SHRINK, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda r,idx: r.src[0].index(*[a+ss if resolve(ss != 0) else a for a,(ss,_) in zip(idx.src[1:], r.arg)], dtype=idx.dtype, arg=idx.arg)), - (UPat(Ops.INDEX, src=(UPat(Ops.PERMUTE, name="r"),), allow_any_len=True, name="idx"), + (UPat(Ops.PERMUTE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda r,idx: r.src[0].index(*[idx.src[1+p] for p in argsort(idx.src[0].arg)], dtype=idx.dtype, arg=idx.arg)), - (UPat(Ops.INDEX, src=(UPat(Ops.FLIP, name="r"),), allow_any_len=True, name="idx"), + (UPat(Ops.FLIP, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), lambda r,idx: r.src[0].index(*[((s-1)-a) if f else a for a,s,f in zip(idx.src[1:], r.shape, r.arg)], dtype=idx.dtype, arg=idx.arg)), # expand needs to end ranges - (UPat(Ops.INDEX, src=(UPat(Ops.EXPAND, name="r"),), allow_any_len=True, name="idx"), map_expand), + (UPat(Ops.EXPAND, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_expand), # reshape does a lot of symbolic stuff - (UPat(Ops.INDEX, src=(UPat(Ops.RESHAPE, name="r"),), allow_any_len=True, name="idx"), map_reshape), + (UPat(Ops.RESHAPE, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_reshape), # pad adds min and max - (UPat(Ops.INDEX, src=(UPat(Ops.PAD, name="r"),), allow_any_len=True, name="idx"), map_pad), + (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad), ]) def map_partial_contiguous(ctx:RangeifyContext, x:UOp, idx:UOp): @@ -189,8 +197,7 @@ def map_contiguous(ctx:RangeifyContext, x:UOp): ranges = [] for s in x.shape: ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.int, 0)) - ret = x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device) - return ret.forced_reshape(x.shape) + return x.src[0].index(*ranges).bufferize(*[x for x in ranges if x.op is not Ops.CONST], arg=x.device).forced_reshape(x.shape) def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp): rngs = list(idx.src[1:]) @@ -304,7 +311,7 @@ def remove_bufferize(b2:UOp, idx2:UOp): assert all(x.op is Ops.RANGE for x in b2.src[1:]) return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:]))) -pm_cleanups = pm_mops+PatternMatcher([ +pm_cleanups = double_reshape+pm_mops+PatternMatcher([ #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes), # remove noop buffers. if we look at the next index we can remove even more of these # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more