rangeify: late zero folding (#12464)

* rangeify: late zero folding

* early

* not kernels

* none

* multi

* linter

* mstack is sink comment

* more comment
authored by qazal on 2025-10-06 12:52:33 +03:00, committed by GitHub
parent 0c015a24fe
commit 76e8a3250c
3 changed files with 15 additions and 7 deletions

View File

@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
-from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails
+from test.helpers import REAL_DEV, not_support_multi_device
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
@@ -201,7 +201,6 @@ class TestMultiTensor(unittest.TestCase):
     fn = f(n)
     np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)

-  @expect_rangeify_fails # TODO: fix
   def test_allreduce_shard_ring_sum(self):
     for axis in (0, 1, None):
       for use_ring in (0, 2):

View File

@@ -814,6 +814,14 @@ class TestSchedule(unittest.TestCase):
     check_schedule(a, 0)
     self.assertEqual(a.tolist(), [])

+  def test_zero_size_children(self):
+    r = Tensor.ones(1,2).contiguous().realize().sum(axis=(1,), keepdim=True)
+    ax = r.reshape(1)*2
+    ay = r.reshape(1).shrink(((1,1),))*2
+    out = ax+ay.pad(((1, 0),))
+    run_schedule(check_schedule(out, 1))
+    self.assertEqual(out.item(), 4.)
+
   def test_reduce_permute_nofuse(self):
     x = Tensor.empty(32, 32, 32)
     y = Tensor.empty(32, 32)
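
For reference, the new test boils down to this standalone Tensor-level sketch (a minimal sketch assuming a default tinygrad install; the shrink to ((1,1),) produces a size-0 view that is now folded to a const in the earliest rewrites instead of inside rangeify):

from tinygrad import Tensor
r = Tensor.ones(1, 2).contiguous().realize().sum(axis=(1,), keepdim=True)  # shape (1, 1), value 2
ax = r.reshape(1) * 2                     # [4.]
ay = r.reshape(1).shrink(((1, 1),)) * 2   # size-0 child of the same reduce
out = ax + ay.pad(((1, 0),))              # pad the empty branch back to shape (1,); pad value is 0
print(out.item())                         # 4.0 (the test above also asserts this schedules as one kernel)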

View File

@@ -57,6 +57,9 @@ earliest_rewrites = PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # handle size 0
+  (UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x.st is not None and x.size == 0 else None),
   # remove contiguous on movement ops before a copy on disk
   (UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
    lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
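
Roughly, the new entry folds any zero-size op other than SINK (anything with a ShapeTracker and size == 0) to a const 0 carrying the original tag, complementing the rule above it that turns a reduce over a zero-size source into the reduce's identity element. A Tensor-level sketch of that identity-element case (assuming a default tinygrad install):

from tinygrad import Tensor
empty = Tensor.ones(0, 3)   # size-0 tensor
s = empty.sum(axis=0)       # reduce over the empty axis: folds to ADD's identity element
print(s.tolist())           # [0.0, 0.0, 0.0]
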
@@ -133,7 +136,8 @@ def extract_children(ctx:ChildrenContext, x:UOp):
   children_map = x.get_children_map()
   ctx.children = {}
   for k,v in children_map.items():
-    non_sink_children = [u for u in v if u.op is not Ops.SINK]
+    # NOTE: we treat mstack children like sink here
+    non_sink_children = [u for u in v if u.op not in {Ops.SINK, Ops.MSTACK}]
     if len(non_sink_children) <= 1: continue
     # NOTE: this gate shouldn't be here
     if k.op_in_parents(Ops.REDUCE_AXIS) and k.op_in_parents(Ops.BUFFER, Ops.CONTIGUOUS):
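
The extract_children change only widens the filter: consumers that merely collect outputs (SINK, and now MSTACK, the multi-device stack) no longer count toward a node having multiple real children. A standalone sketch of the filtering idea (plain Python for illustration, not tinygrad internals; Node, shared_nodes, and the string ops are made up):

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str

def shared_nodes(children_map, sink_like=frozenset({"SINK", "MSTACK"})):
  # a node only counts as shared when it has two or more non-sink-like consumers
  out = {}
  for k, v in children_map.items():
    real = [u for u in v if u.op not in sink_like]
    if len(real) > 1: out[k] = real
  return out

# one elementwise consumer plus an MSTACK is no longer treated as two children
print(shared_nodes({"buf": [Node("ADD"), Node("MSTACK")]}))  # {}
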
@@ -363,9 +367,6 @@ pm_rangeify = pm_mops+PatternMatcher([
   # handle arg on any op with weight. old endrange stuff
   (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
-  # handle size 0
-  (UPat(Ops.INDEX, name="x"), lambda x: x.replace(src=(x.const_like(0),)+x.src[1:]) if x.st is not None and x.size == 0 else None),
   # handle assign
   (UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
    lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \
@@ -572,7 +573,7 @@ pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
   # move RESHAPEs through MSELECT/MSTACK
   (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
-   lambda m: m.replace(src=tuple([x.src[0] for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
+   lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
 ])

 # *****************