mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Revert "shrink guarded ranges" (#15271)
This commit is contained in:
committed by
GitHub
parent
d9951e2f8e
commit
7cf4b16c91
@@ -46,9 +46,9 @@ class TestHelpers(unittest.TestCase):
|
||||
self.assertTrue((rng+2).is_increasing())
|
||||
|
||||
class TestValidIdxSimplification(unittest.TestCase):
|
||||
def check(self, load, sidx, svalid, extra=()):
|
||||
def check(self, load, sidx, svalid):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
load = full_rewrite_to_sink(UOp.sink(load, *extra)).src[0]
|
||||
load = full_rewrite_to_sink(load.sink()).src[0]
|
||||
idx, valid = load.src[0].src[1], load.src[0].src[2]
|
||||
check_uop_against_string(self, idx, sidx)
|
||||
check_uop_against_string(self, valid, svalid)
|
||||
@@ -156,12 +156,9 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
idx = (alu15*-31)+(((((alu11+218)//224)+ridx0)%30)*1568)
|
||||
valid = (ridx2<1)&(ridx1<6)
|
||||
load = get_gated_load_uop(valid, idx)
|
||||
# prevent ridx1 and ridx2 from being shrunk
|
||||
red = UOp(Ops.REDUCE, dtypes.float, (load, ridx1, ridx2), Ops.ADD)
|
||||
self.check(load,
|
||||
"(r0*1568)",
|
||||
"((r2<1)&(r1<6))",
|
||||
extra=(red,))
|
||||
"((r2<1)&(r1<6))")
|
||||
|
||||
def test_valid_becomes_const1_z3(self):
|
||||
from z3 import Ints, Solver, And, If, Not, unsat
|
||||
@@ -486,54 +483,5 @@ class TestDropTrueGate(unittest.TestCase):
|
||||
# the True gate should be dropped (INDEX should only have 2 sources)
|
||||
self.assertEqual(len(result.src), 2, "True gate should be dropped from INDEX")
|
||||
|
||||
class TestRangeShrink(unittest.TestCase):
|
||||
def get_ranges(self, sink):
|
||||
with Context(NOOPT=1, SPEC=0):
|
||||
result = full_rewrite_to_sink(sink)
|
||||
return [u for u in result.toposort() if u.op is Ops.RANGE]
|
||||
|
||||
def test_range_shrink_single_guard(self):
|
||||
# range 0..203 guarded by r < 4 everywhere -> shrink to 0..3
|
||||
r = Range(0, 204)
|
||||
load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
|
||||
ranges = self.get_ranges(load.sink())
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 4)
|
||||
|
||||
def test_range_shrink_picks_max_guard(self):
|
||||
# two loads guard the same range with r < 4 and r < 8 -> shrink to max(4, 8) = 8
|
||||
r = Range(0, 204)
|
||||
load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
|
||||
load2 = get_gated_load_uop(r < UOp.const(dtypes.index, 8), r)
|
||||
ranges = self.get_ranges(UOp.sink(load1, load2))
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 8)
|
||||
|
||||
def test_range_no_shrink_guard_ge_max(self):
|
||||
# guard r < 300 with range max 204 -> no shrink (guard doesn't constrain)
|
||||
r = Range(0, 204)
|
||||
load = get_gated_load_uop(r < UOp.const(dtypes.index, 300), r)
|
||||
ranges = self.get_ranges(load.sink())
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 204)
|
||||
|
||||
def test_range_no_shrink_when_unguarded_elsewhere(self):
|
||||
# one load guards r < 4, but another load uses r without a gate -> no shrink
|
||||
r = Range(0, 204)
|
||||
load1 = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
|
||||
load2 = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.PARAM, dtypes.float.ptr(), arg=1).index(r, ptr=True),))
|
||||
ranges = self.get_ranges(UOp.sink(load1, load2))
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 204)
|
||||
|
||||
def test_range_no_shrink_when_used_in_reduce(self):
|
||||
# range used in both a gated load AND directly in the reduce expression -> no shrink
|
||||
r = Range(0, 204)
|
||||
gated_load = get_gated_load_uop(r < UOp.const(dtypes.index, 4), r)
|
||||
red = UOp(Ops.REDUCE, dtypes.float, (r.cast(dtypes.float) + gated_load, r), Ops.ADD)
|
||||
ranges = self.get_ranges(red.sink())
|
||||
self.assertEqual(len(ranges), 1)
|
||||
self.assertEqual(ranges[0].src[0].arg, 204)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -423,12 +423,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
d0 = UOp(Ops.PARAM, dtypes.long.ptr(), (), 0)
|
||||
ld = d0.index(ridx0.valid(ridx0<50))
|
||||
w = (ridx0<50).where(ld, 5)
|
||||
# prevent ridx0 from being shrunk
|
||||
red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD)
|
||||
uops = to_uops_list([w, red])
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg==5
|
||||
if u.op is Ops.LOAD: assert u.src[1].arg==5
|
||||
|
||||
def test_where_on_gated_load_folds_swapped_branches(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
@@ -446,12 +444,10 @@ class TestUOpGraph(unittest.TestCase):
|
||||
gate_idx = ridx0.valid((ridx0<50))
|
||||
ld = d0.index(gate_idx).cast(dtypes.float)
|
||||
w = (ridx0<50).where(ld, 5.0)
|
||||
# prevent ridx0 from being shrunk
|
||||
red = UOp(Ops.REDUCE, dtypes.long, (ridx0.cast(dtypes.long), ridx0), Ops.ADD)
|
||||
uops = to_uops_list([w, red])
|
||||
uops = to_uops_list([w])
|
||||
for u in uops:
|
||||
assert u.op is not Ops.WHERE
|
||||
if u.op is Ops.LOAD and u.src[0].src[0].op is Ops.PARAM: assert u.src[1].arg == 5
|
||||
if u.op is Ops.LOAD: assert u.src[1].arg == 5
|
||||
|
||||
def test_where_in_store_becomes_gate(self):
|
||||
ridx0 = UOp.range(100, 0)
|
||||
|
||||
@@ -48,7 +48,7 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
sink = graph_rewrite(sink, sym+pm_flatten_range, name="initial symbolic")
|
||||
|
||||
# optimize (schedule) the AST
|
||||
sink = graph_rewrite(sink, pm_simplify_ranges, ctx={}, name="simplify ranges")
|
||||
sink = graph_rewrite(sink, pm_simplify_ranges, name="simplify ranges")
|
||||
|
||||
# do postrange optimization, BEAM or hand_coded_optimizations
|
||||
sink = apply_opts(sink, ren)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import itertools
|
||||
from typing import Callable
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops, graph_rewrite, _substitute, range_start
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import partition
|
||||
@@ -37,32 +36,22 @@ def simplify_merge_adjacent(u:UOp) -> UOp|None:
|
||||
u = nidx
|
||||
return u
|
||||
|
||||
def mark_gated(ctx, idx):
|
||||
if idx.src[1].op is Ops.WHERE:
|
||||
x, cond = idx.src[1].get_idx(), idx.src[1].get_valid()
|
||||
# get all ranges r with guards "r < c" for some const c
|
||||
guards = {r:c for v in cond.split_uop(Ops.AND) if v.op is Ops.CMPLT and (r:=v.src[0]).op is Ops.RANGE and (c:=v.src[1]).op is Ops.CONST}
|
||||
else: x, guards = idx, {}
|
||||
# ensure that c actually limits the range, and that we choose max(c_i)
|
||||
ctx |= {r:c for r,c in guards.items() if c.arg < r.src[0].arg and (r not in ctx or ctx[r].arg < c.arg)}
|
||||
# but if a range is ever ungated, we cannot shrink it
|
||||
ctx |= {r:r.src[0] for r in x.ranges if r not in guards}
|
||||
|
||||
pm_simplify_ranges = PatternMatcher([
|
||||
(UPat((Ops.END, Ops.REDUCE), name="u"), simplify_merge_adjacent),
|
||||
(UPat(Ops.INDEX, name="idx"), mark_gated),
|
||||
# reduce ranges can't be shrunk
|
||||
(UPat(Ops.REDUCE, name="red"), lambda ctx, red: ctx.update({r:r.src[0] for r in red.src[1:]})),
|
||||
(UPat(Ops.SINK, name="x"), lambda ctx, x: do_substitute(ctx, x, lambda r,c: r.replace(src=(c,)))),
|
||||
])
|
||||
|
||||
def mark_range_mod(ctx:dict[UOp, UOp|None], r:UOp, c:UOp) -> None:
|
||||
if r not in ctx and r.src[0].op is Ops.CONST and r.src[0].divides(c.arg) is not None: ctx[r] = c
|
||||
|
||||
def do_substitute(ctx:dict, x: UOp, sub_fxn:Callable[[UOp, UOp], UOp]) -> UOp|None:
|
||||
ret = x.substitute({k:sub_fxn(k,v) for k,v in ctx.items() if v is not None})
|
||||
def do_substitute(ctx:dict[UOp, UOp|None], x: UOp) -> UOp|None:
|
||||
subs = {}
|
||||
for k,v in ctx.items():
|
||||
if v is not None:
|
||||
subs[k] = k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1]))
|
||||
if not len(subs): return None
|
||||
ret = x.substitute(subs).simplify()
|
||||
ctx.clear()
|
||||
return None if ret is x else ret.simplify()
|
||||
return ret
|
||||
|
||||
def dont_sub_ranges_for_image(ctx:dict[UOp, UOp|None], x:UOp) -> None:
|
||||
if isinstance(x.src[0].src[0].dtype, ImageDType):
|
||||
@@ -71,8 +60,7 @@ def dont_sub_ranges_for_image(ctx:dict[UOp, UOp|None], x:UOp) -> None:
|
||||
pm_split_ranges = PatternMatcher([
|
||||
(UPat(Ops.RANGE, name="r")%UPat.cvar("c"), mark_range_mod),
|
||||
(UPat(Ops.STORE, name="x"), dont_sub_ranges_for_image),
|
||||
(UPat(Ops.SINK, name="x"), lambda ctx, x: do_substitute(ctx, x,
|
||||
lambda k,v: k.replace(src=(k.src[0]//v,), arg=k.arg[0:-1]+(0,k.arg[-1]))*v + k.replace(src=(v,), arg=k.arg[0:-1]+(1,k.arg[-1])))),
|
||||
(UPat(Ops.SINK, name="x"), do_substitute),
|
||||
])
|
||||
|
||||
# **** reduce simplification ****
|
||||
|
||||
Reference in New Issue
Block a user