From 76e8a3250c585be5cccb1104d227c88695501d7c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 6 Oct 2025 12:52:33 +0300 Subject: [PATCH] rangeify: late zero folding (#12464) * rangeify: late zero folding * early * not kernels * none * multi * linter * mstack is sink comment * more comment --- test/test_multitensor.py | 3 +-- test/test_schedule.py | 8 ++++++++ tinygrad/schedule/rangeify.py | 11 ++++++----- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index d1c24b0d65..5711018454 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule import numpy as np from hypothesis import given, strategies as strat, settings -from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails +from test.helpers import REAL_DEV, not_support_multi_device settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) settings.load_profile("my_profile") @@ -201,7 +201,6 @@ class TestMultiTensor(unittest.TestCase): fn = f(n) np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6) - @expect_rangeify_fails # TODO: fix def test_allreduce_shard_ring_sum(self): for axis in (0, 1, None): for use_ring in (0, 2): diff --git a/test/test_schedule.py b/test/test_schedule.py index 7815bbb2fb..3e7055c6a3 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -814,6 +814,14 @@ class TestSchedule(unittest.TestCase): check_schedule(a, 0) self.assertEqual(a.tolist(), []) + def test_zero_size_children(self): + r = Tensor.ones(1,2).contiguous().realize().sum(axis=(1,), keepdim=True) + ax = r.reshape(1)*2 + ay = r.reshape(1).shrink(((1,1),))*2 + out = ax+ay.pad(((1, 0),)) + run_schedule(check_schedule(out, 1)) + self.assertEqual(out.item(), 4.) + def test_reduce_permute_nofuse(self): x = Tensor.empty(32, 32, 32) y = Tensor.empty(32, 32) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index c4cd782b18..2e47656d9a 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -57,6 +57,9 @@ earliest_rewrites = PatternMatcher([ (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), + # handle size 0 + (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x.st is not None and x.size == 0 else None), + # remove contiguous on movement ops before a copy on disk (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"), lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None), @@ -133,7 +136,8 @@ def extract_children(ctx:ChildrenContext, x:UOp): children_map = x.get_children_map() ctx.children = {} for k,v in children_map.items(): - non_sink_children = [u for u in v if u.op is not Ops.SINK] + # NOTE: we treat mstack children like sink here + non_sink_children = [u for u in v if u.op not in {Ops.SINK, Ops.MSTACK}] if len(non_sink_children) <= 1: continue # NOTE: this gate shouldn't be here if k.op_in_parents(Ops.REDUCE_AXIS) and k.op_in_parents(Ops.BUFFER, Ops.CONTIGUOUS): @@ -363,9 +367,6 @@ pm_rangeify = pm_mops+PatternMatcher([ # handle arg on any op with weight. old endrange stuff (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis), - # handle size 0 - (UPat(Ops.INDEX, name="x"), lambda x: x.replace(src=(x.const_like(0),)+x.src[1:]) if x.st is not None and x.size == 0 else None), - # handle assign (UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"), lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \ @@ -572,7 +573,7 @@ pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([ # move RESHAPEs through MSELECT/MSTACK (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"), - lambda m: m.replace(src=tuple([x.src[0] for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)), + lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)), ]) # *****************