From d9951e2f8e6468df4afabeb5caafd9df3a2c5352 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Sat, 14 Mar 2026 00:38:48 -0700 Subject: [PATCH] shrink guarded ranges (#15263) --- test/null/test_simplify_valid_idx.py | 58 ++++++++++++++++++++++++++-- test/null/test_uop_graph.py | 12 ++++-- tinygrad/codegen/__init__.py | 2 +- tinygrad/codegen/simplify.py | 30 +++++++++----- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/test/null/test_simplify_valid_idx.py b/test/null/test_simplify_valid_idx.py index 8265c4272a..ab0264757c 100644 --- a/test/null/test_simplify_valid_idx.py +++ b/test/null/test_simplify_valid_idx.py @@ -46,9 +46,9 @@ class TestHelpers(unittest.TestCase): self.assertTrue((rng+2).is_increasing()) class TestValidIdxSimplification(unittest.TestCase): - def check(self, load, sidx, svalid): + def check(self, load, sidx, svalid, extra=()): with Context(NOOPT=1, SPEC=0): - load = full_rewrite_to_sink(load.sink()).src[0] + load = full_rewrite_to_sink(UOp.sink(load, *extra)).src[0] idx, valid = load.src[0].src[1], load.src[0].src[2] check_uop_against_string(self, idx, sidx) check_uop_against_string(self, valid, svalid) @@ -156,9 +156,12 @@ class TestValidIdxSimplification(unittest.TestCase): idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568) valid = (ridx2<1)&(ridx1<6) load = get_gated_load_uop(valid, idx) + # prevent ridx1 and ridx2 from being shrunk + red = UOp(Ops.REDUCE, dtypes.float, (load, ridx1, ridx2), Ops.ADD) self.check(load, "(r0*1568)", - "((r2<1)&(r1<6))") + "((r2<1)&(r1<6))", + extra=(red,)) def test_valid_becomes_const1_z3(self): from z3 import Ints, Solver, And, If, Not, unsat @@ -483,5 +486,54 @@ class TestDropTrueGate(unittest.TestCase): # the True gate should be dropped (INDEX should only have 2 sources) self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX") +class TestRangeShrink(unittest.TestCase): + def get_ranges(self, sink): + with Context(NOOPT=1, SPEC=0): + result = full_rewrite_to_sink(sink) + return [u for u in result.toposort() if u.op is Ops.RANGE] + + def test_range_shrink_single_guard(self): + # range 0..203 guarded by r < 4 everywhere -> shrink to 0..3 + r = Range(0, 204) + load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r) + ranges = self.get_ranges(load.sink()) + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0].src[0].arg, 4) + + def test_range_shrink_picks_max_guard(self): + # two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8 + r = Range(0, 204) + load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r) + load2 = get_gated_load_uop(r < UOp.const(dtypes.index, 8), r) + ranges = self.get_ranges(UOp.sink(load1, load2)) + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0].src[0].arg, 8) + + def test_range_no_shrink_guard_ge_max(self): + # guard r < 300 with range max 204 -> no shrink (guard doesn't constrain) + r = Range(0, 204) + load = get_gated_load_uop(r < UOp.const(dtypes.index, 300), r) + ranges = self.get_ranges(load.sink()) + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0].src[0].arg, 204) + + def test_range_no_shrink_when_unguarded_elsewhere(self): + # one load guards r < 4, but another load uses r without a gate -> no shrink + r = Range(0, 204) + load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r) + load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),)) + ranges = self.get_ranges(UOp.sink(load1, load2)) + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0].src[0].arg, 204) + + def test_range_no_shrink_when_used_in_reduce(self): + # range used in both a gated load AND directly in the reduce expression -> no shrink + r = Range(0, 204) + gated_load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r) + red = UOp(Ops.REDUCE, dtypes.float, (r.cast(dtypes.float) + gated_load, r), Ops.ADD) + ranges = self.get_ranges(red.sink()) + self.assertEqual(len(ranges), 1) + self.assertEqual(ranges[0].src[0].arg, 204) + if __name__ == '__main__': unittest.main() diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py index 6691091ac1..95d7742e9f 100644 --- a/test/null/test_uop_graph.py +++ b/test/null/test_uop_graph.py @@ -423,10 +423,12 @@ class TestUOpGraph(unittest.TestCase): d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0) ld = d0.index(ridx0.valid(ridx0<50)) w = (ridx0<50).where(ld, 5) - uops = to_uops_list([w]) + # prevent ridx0 from being shrunk + red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD) + uops = to_uops_list([w, red]) for u in uops: assert u.op is not Ops.WHERE - if u.op is Ops.LOAD: assert u.src[1].arg==5 + if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5 def test_where_on_gated_load_folds_swapped_branches(self): ridx0 = UOp.range(100, 0) @@ -444,10 +446,12 @@ class TestUOpGraph(unittest.TestCase): gate_idx = ridx0.valid((ridx0<50)) ld = d0.index(gate_idx).cast(dtypes.float) w = (ridx0<50).where(ld, 5.0) - uops = to_uops_list([w]) + # prevent ridx0 from being shrunk + red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD) + uops = to_uops_list([w, red]) for u in uops: assert u.op is not Ops.WHERE - if u.op is Ops.LOAD: assert u.src[1].arg == 5 + if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5 def test_where_in_store_becomes_gate(self): ridx0 = UOp.range(100, 0) diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index fd526d270b..c62df14f04 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -48,7 +48,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic") # optimize (schedule) the AST - sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges") + sink = graph_rewrite(sink, pm_simplify_ranges, ctx={}, name="simplify ranges") # do postrange optimization, BEAM or hand_coded_optimizations sink = apply_opts(sink, ren) diff --git a/tinygrad/codegen/simplify.py b/tinygrad/codegen/simplify.py index 1d464b2055..401e76e87b 100644 --- a/tinygrad/codegen/simplify.py +++ b/tinygrad/codegen/simplify.py @@ -1,4 +1,5 @@ import itertools +from typing import Callable from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start from tinygrad.uop.symbolic import symbolic from tinygrad.helpers import partition @@ -36,22 +37,32 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None: u = nidx return u +def mark_gated(ctx, idx): + if idx.src[1].op is Ops.WHERE: + x, cond = idx.src[1].get_idx(), idx.src[1].get_valid() + # get all ranges r with guards "r < c" for some const c + guards = {r:c for v in cond.split_uop(Ops.AND) if v.op is Ops.CMPLT and (r:=v.src[0]).op is Ops.RANGE and (c:=v.src[1]).op is Ops.CONST} + else: x, guards = idx, {} + # ensure that c actually limits the range, and that we choose max(c_i) + ctx |= {r:c for r,c in guards.items() if c.arg < r.src[0].arg and (r not in ctx or ctx[r].arg < c.arg)} + # but if a range is ever ungated, we cannot shrink it + ctx |= {r:r.src[0] for r in x.ranges if r not in guards} + pm_simplify_ranges = PatternMatcher([ (UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent), + (UPat(Ops.INDEX, name="idx"), mark_gated), + # reduce ranges can't be shrunk + (UPat(Ops.REDUCE, name="red"), lambda ctx, red: ctx.update({r:r.src[0] for r in red.src[1:]})), + (UPat(Ops.SINK, name="x"), lambda ctx, x: do_substitute(ctx, x, lambda r,c: r.replace(src=(c,)))), ]) def mark_range_mod(ctx:dict[UOp, UOp|None], r:UOp, c:UOp) -> None: if r not in ctx and r.src[0].op is Ops.CONST and r.src[0].divides(c.arg) is not None: ctx[r] = c -def do_substitute(ctx:dict[UOp, UOp|None], x: UOp) -> UOp|None: - subs = {} - for k,v in ctx.items(): - if v is not None: - subs[k] = k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1])) - if not len(subs): return None - ret = x.substitute(subs).simplify() +def do_substitute(ctx:dict, x: UOp, sub_fxn:Callable[[UOp, UOp], UOp]) -> UOp|None: + ret = x.substitute({k:sub_fxn(k,v) for k,v in ctx.items() if v is not None}) ctx.clear() - return ret + return None if ret is x else ret.simplify() def dont_sub_ranges_for_image(ctx:dict[UOp, UOp|None], x:UOp) -> None: if isinstance(x.src[0].src[0].dtype, ImageDType): @@ -60,7 +71,8 @@ def dont_sub_ranges_for_image(ctx:dict[UOp, UOp|None], x:UOp) -> None: pm_split_ranges = PatternMatcher([ (UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod), (UPat(Ops.STORE, name="x"), dont_sub_ranges_for_image), - (UPat(Ops.SINK, name="x"), do_substitute), + (UPat(Ops.SINK, name="x"), lambda ctx, x: do_substitute(ctx, x, + lambda k,v: k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1])))), ]) # **** reduce simplification ****