shrink guarded ranges, try 2 (#15272)

This commit is contained in:
Christopher Milan
2026-03-14 01:24:05 -07:00
committed by GitHub
parent 7cf4b16c91
commit dabdc986df
4 changed files with 92 additions and 17 deletions

View File

@@ -46,9 +46,9 @@ class TestHelpers(unittest.TestCase):
self.assertTrue((rng+2).is_increasing())
class TestValidIdxSimplification(unittest.TestCase):
def check(self, load, sidx, svalid):
def check(self, load, sidx, svalid, extra=()):
with Context(NOOPT=1, SPEC=0):
load = full_rewrite_to_sink(load.sink()).src[0]
load = full_rewrite_to_sink(UOp.sink(load, *extra)).src[0]
idx, valid = load.src[0].src[1], load.src[0].src[2]
check_uop_against_string(self, idx, sidx)
check_uop_against_string(self, valid, svalid)
@@ -156,9 +156,12 @@ class TestValidIdxSimplification(unittest.TestCase):
idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
valid = (ridx2<1)&(ridx1<6)
load = get_gated_load_uop(valid, idx)
# prevent ridx1 and ridx2 from being shrunk
red = UOp(Ops.REDUCE, dtypes.float, (load, ridx1, ridx2), Ops.ADD)
self.check(load,
"(r0*1568)",
"((r2<1)&(r1<6))")
"((r2<1)&(r1<6))",
extra=(red,))
def test_valid_becomes_const1_z3(self):
from z3 import Ints, Solver, And, If, Not, unsat
@@ -483,5 +486,61 @@ class TestDropTrueGate(unittest.TestCase):
# the True gate should be dropped (INDEX should only have 2 sources)
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
class TestRangeShrink(unittest.TestCase):
def get_ranges(self, sink):
with Context(NOOPT=1, SPEC=0):
result = full_rewrite_to_sink(sink)
return [u for u in result.toposort() if u.op is Ops.RANGE]
def test_range_shrink_single_guard(self):
# range 0..203 guarded by r < 4 everywhere -> shrink to 0..3
r = Range(0, 204)
load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
ranges = self.get_ranges(load.sink())
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 4)
def test_range_shrink_picks_max_guard(self):
# two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8
r = Range(0, 204)
load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
load2 = get_gated_load_uop(r < UOp.const(dtypes.index, 8), r)
ranges = self.get_ranges(UOp.sink(load1, load2))
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 8)
def test_range_no_shrink_guard_ge_max(self):
# guard r < 300 with range max 204 -> no shrink (guard doesn't constrain)
r = Range(0, 204)
load = get_gated_load_uop(r < UOp.const(dtypes.index, 300), r)
ranges = self.get_ranges(load.sink())
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 204)
def test_range_no_shrink_when_unguarded_elsewhere(self):
# one load guards r < 4, but another load uses r without a gate -> no shrink
r = Range(0, 204)
load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
ranges = self.get_ranges(UOp.sink(load1, load2))
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 204)
def test_range_no_shrink_when_used_in_reduce(self):
# range used in both a gated load AND directly in the reduce expression -> no shrink
r = Range(0, 204)
gated_load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
red = UOp(Ops.REDUCE, dtypes.float, (r.cast(dtypes.float) + gated_load, r), Ops.ADD)
ranges = self.get_ranges(red.sink())
self.assertEqual(len(ranges), 1)
self.assertEqual(ranges[0].src[0].arg, 204)
def test_range_shrink_to_single_iteration(self):
# guard r < 1 shrinks range to 1 -> single iteration, range eliminated entirely
r = Range(0, 204)
load = get_gated_load_uop(r < UOp.const(dtypes.index, 1), r)
ranges = self.get_ranges(load.sink())
self.assertEqual(len(ranges), 0)
if __name__ == '__main__':
unittest.main()

View File

@@ -423,10 +423,12 @@ class TestUOpGraph(unittest.TestCase):
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
ld = d0.index(ridx0.valid(ridx0<50))
w = (ridx0<50).where(ld, 5)
uops = to_uops_list([w])
# prevent ridx0 from being shrunk
red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD)
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD: assert u.src[1].arg==5
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
def test_where_on_gated_load_folds_swapped_branches(self):
ridx0 = UOp.range(100, 0)
@@ -444,10 +446,12 @@ class TestUOpGraph(unittest.TestCase):
gate_idx = ridx0.valid((ridx0<50))
ld = d0.index(gate_idx).cast(dtypes.float)
w = (ridx0<50).where(ld, 5.0)
uops = to_uops_list([w])
# prevent ridx0 from being shrunk
red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD)
uops = to_uops_list([w, red])
for u in uops:
assert u.op is not Ops.WHERE
if u.op is Ops.LOAD: assert u.src[1].arg == 5
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
def test_where_in_store_becomes_gate(self):
ridx0 = UOp.range(100, 0)