don't modify the ranges on reduce rewrite (#10062)

* bug in div range folding

* simpler

* oh, this is right for indexing, but the div mod folding needs to be fixed

* reenable

* Passing test_complexity_w_unroll2 (#10068)

* Passing

* remove non_folded_divs

* Add check for negative tern in div folding

* Add test

* bump that limit

* fix casted

---------

Co-authored-by: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
This commit is contained in:
George Hotz
2025-04-28 12:01:19 -04:00
committed by GitHub
parent 5130759605
commit 690dac79b5
5 changed files with 49 additions and 46 deletions

View File

@@ -2,7 +2,7 @@
import unittest
import torch
import numpy as np
from tinygrad.helpers import getenv, Context
from tinygrad.helpers import getenv, Context, GlobalCounters
if getenv("TINY_BACKEND2"):
import extra.torch_backend.backend2
device = "cpu"
@@ -167,6 +167,7 @@ class TestTorchBackend(unittest.TestCase):
def test_mnist_index(self):
with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
GlobalCounters.reset()
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
X_train = torch.tensor(X_train.float().numpy(), device=device)
@@ -174,6 +175,7 @@ class TestTorchBackend(unittest.TestCase):
samples = torch.randint(0, X_train.shape[0], (32,))
X,Y = X_train[samples], Y_train[samples]
X.cpu(), Y.cpu()
self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
if __name__ == "__main__":
unittest.main()

View File

@@ -42,7 +42,7 @@ class TestArange(unittest.TestCase):
if Device.default.renderer.has_local:
# TODO: fix limit
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=100000)
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)

View File

@@ -464,6 +464,11 @@ class TestSymbolic(unittest.TestCase):
unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
def test_arange_unrolled4_mul(self):
gidx = Variable("gidx", 0, 2559)
unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4)
self.helper_test_variable(unrolled_div, 5118, 10236, "((gidx*2)+5118)")
def test_arange_unrolled4_small(self):
gidx = Variable("gidx", 0, 3)
unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
@@ -482,6 +487,11 @@ class TestSymbolic(unittest.TestCase):
unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")
def test_arange_unrolled2_neg(self):
ridx = Variable("ridx", 0, 255)
unrolled_div = -((255-ridx)//2) - ((256-ridx)//2)
self.helper_test_variable(unrolled_div, -255, 0, "(ridx+-255)")
def test_gated_load(self):
idx = Variable("idx", 0, 24)
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")

View File

@@ -344,24 +344,6 @@ def no_vectorized_reduce(inp:UOp, red:UOp):
alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
return UOp(Ops.VECTORIZE, red.dtype, alus)
def range_fold_lo(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
# psuedo code: sum(val if i < cut else 0) for i in range(lo, hi, st))
total = (hi-lo+st-1) // st # real count in the range
length = ((cut-lo+st-1) // st).maximum(0).minimum(total)
return length.cast(val.dtype) * val
def range_fold_hi(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
# psuedo code: sum(val if i >= cut else 0) for i in range(lo, hi, st))
# TODO: this function is so tricky and still probably wrong. test it
total = (hi-lo+st-1) // st # real count in the range
length = ((lo-cut+total*st)//st).maximum(0).minimum(total) # number in cut
return length.cast(val.dtype) * val
def index_fold(buf:UOp, r:UOp, idx:UOp, r2:UOp) -> UOp|None:
if r.arg != r2.arg: return None
base_idx = (idx-r2.src[0])//r2.src[2] # indexed from 0 to the length of the range
return buf.index(base_idx.cast(r.dtype)*r.src[2] + r.src[0], (idx >= r2.src[0]) & (idx < r2.src[1]))
def reduce_rangeless(red:UOp):
# TODO: share code with reduce_unparented
if red.arg not in {Ops.ADD, Ops.MAX}: return None
@@ -370,22 +352,24 @@ def reduce_rangeless(red:UOp):
ret = red.src[0]
if red.arg is Ops.ADD:
for r in red.src[1:]:
total = (r.src[1]-r.src[0]+r.src[2]-1) // r.src[2] # real count in the range
ret = ret * total.cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
pm_reduce_collapse = PatternMatcher([
# put third arg in range
(UPat(Ops.RANGE, src=(UPat.var(), UPat.var()), name="r"), lambda r: r.replace(src=r.src+(UOp.const(r.dtype, 1),))),
# mul to range
(UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
# add to range
(UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
# fold the range with 0 in either the true or false slot
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
.where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_lo),
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_hi),
# lift x+y out of reduce on lt
((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
# lift x*y out of reduce
((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
# lift x+y out of reduce on ne
((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
# fold the range
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (r.src[1]-cut).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (cut-r.src[0]).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
# devectorize REDUCE
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
# REDUCE on ADD
@@ -400,11 +384,10 @@ pm_reduce_collapse = PatternMatcher([
UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= r.src[0]) & (idx.cast(r.dtype) < r.src[1]))),
# index/load. TODO: this is more aggressive than needed
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
# cast on RANGE (fix torch indexing)
(UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
# AND on WHERE
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),

View File

@@ -4,7 +4,7 @@ import math, operator, struct, functools
from collections import defaultdict
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
from tinygrad.dtype import ConstType, dtypes, PtrDType
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod
from tinygrad.codegen.transcendental import xpow
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
@@ -75,26 +75,33 @@ def split_uop(x:UOp, sep:Ops):
for s in x.src: yield from split_uop(s, sep)
else: yield x
def fold_unrolled_divs(divs:UOp):
def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
# div pattern in unrolled arange
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
for u in add_chain:
seen_const, ans, offset = [], None, 0
for u in split_uop(divs, Ops.ADD):
if fac!=1:
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
u = u.src[0]
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
if denominator is None: denominator = u.src[1].arg
if denominator != u.src[1].arg: return None
if (s0:=u.src[0]).vmin < 0: return None
# assumed CONST is the last of an ADD
if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
seen_const.append(s0.src[1].arg)
if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
const = s0.src[1].arg
offset += cdiv(const, denominator)
seen_const.append(cmod(const, denominator))
s0 = s0.src[0]
else: seen_const.append(0)
if ans is None: ans = s0
if ans is not s0: return None
if denominator is None: return None
if ans is None: return None
# the first (denominator-len(seen_const)) terms may have been folded to 0 already
for i in range(denominator-len(seen_const)):
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
if sorted(seen_const)==list(range(denominator)):
return fac*(ans + offset)
return None
def lt_folding(x:UOp, c:int) -> UOp|None:
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
@@ -265,7 +272,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# unrolled arange div folding
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
# generic lt folding
(UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
# canonicalize a simplex with positive coefficients > 0