mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
don't modify the ranges on reduce rewrite (#10062)
* bug in div range folding
* simpler
* oh, this is right for indexing, but the div mod folding needs to be fixed
* reenable
* Passing test_complexity_w_unroll2 (#10068)
* Passing
* remove non_folded_divs
* Add check for negative tern in div folding
* Add test
* bump that limit
* fix casted

---------

Co-authored-by: Sieds Lykles <93992551+S-Lykles@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
import unittest
|
||||
import torch
|
||||
import numpy as np
|
||||
from tinygrad.helpers import getenv, Context
|
||||
from tinygrad.helpers import getenv, Context, GlobalCounters
|
||||
if getenv("TINY_BACKEND2"):
|
||||
import extra.torch_backend.backend2
|
||||
device = "cpu"
|
||||
@@ -167,6 +167,7 @@ class TestTorchBackend(unittest.TestCase):
|
||||
|
||||
def test_mnist_index(self):
|
||||
with Context(FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
|
||||
GlobalCounters.reset()
|
||||
from tinygrad.nn.datasets import mnist
|
||||
X_train, Y_train, _, _ = mnist()
|
||||
X_train = torch.tensor(X_train.float().numpy(), device=device)
|
||||
@@ -174,6 +175,7 @@ class TestTorchBackend(unittest.TestCase):
|
||||
samples = torch.randint(0, X_train.shape[0], (32,))
|
||||
X,Y = X_train[samples], Y_train[samples]
|
||||
X.cpu(), Y.cpu()
|
||||
self.assertLessEqual(GlobalCounters.global_ops, 10_000_000)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestArange(unittest.TestCase):
|
||||
|
||||
if Device.default.renderer.has_local:
|
||||
# TODO: fix limit
|
||||
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
|
||||
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=100000)
|
||||
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
|
||||
|
||||
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
|
||||
|
||||
@@ -464,6 +464,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
unrolled_div = (gidx+2561)//4+(gidx+2562)//4+(gidx+2560)//4+(gidx+2559)//4
|
||||
self.helper_test_variable(unrolled_div, 2559, 5118, "(gidx+2559)")
|
||||
|
||||
def test_arange_unrolled4_mul(self):
|
||||
gidx = Variable("gidx", 0, 2559)
|
||||
unrolled_div = 2*((gidx+2561)//4)+2*((gidx+2562)//4)+2*((gidx+2560)//4)+2*((gidx+2559)//4)
|
||||
self.helper_test_variable(unrolled_div, 5118, 10236, "((gidx*2)+5118)")
|
||||
|
||||
def test_arange_unrolled4_small(self):
|
||||
gidx = Variable("gidx", 0, 3)
|
||||
unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
|
||||
@@ -482,6 +487,11 @@ class TestSymbolic(unittest.TestCase):
|
||||
unrolled_div = (gidx+2559)//2+(gidx+2560)//2+3
|
||||
self.helper_test_variable(unrolled_div, 2562, 5121, "(gidx+2562)")
|
||||
|
||||
def test_arange_unrolled2_neg(self):
|
||||
ridx = Variable("ridx", 0, 255)
|
||||
unrolled_div = -((255-ridx)//2) - ((256-ridx)//2)
|
||||
self.helper_test_variable(unrolled_div, -255, 0, "(ridx+-255)")
|
||||
|
||||
def test_gated_load(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable(idx//4, 0, 6, "(idx//4)")
|
||||
|
||||
@@ -344,24 +344,6 @@ def no_vectorized_reduce(inp:UOp, red:UOp):
|
||||
alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
|
||||
return UOp(Ops.VECTORIZE, red.dtype, alus)
|
||||
|
||||
def range_fold_lo(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
|
||||
# psuedo code: sum(val if i < cut else 0) for i in range(lo, hi, st))
|
||||
total = (hi-lo+st-1) // st # real count in the range
|
||||
length = ((cut-lo+st-1) // st).maximum(0).minimum(total)
|
||||
return length.cast(val.dtype) * val
|
||||
|
||||
def range_fold_hi(lo:UOp, hi:UOp, st:UOp, cut:UOp, val:UOp) -> UOp:
|
||||
# psuedo code: sum(val if i >= cut else 0) for i in range(lo, hi, st))
|
||||
# TODO: this function is so tricky and still probably wrong. test it
|
||||
total = (hi-lo+st-1) // st # real count in the range
|
||||
length = ((lo-cut+total*st)//st).maximum(0).minimum(total) # number in cut
|
||||
return length.cast(val.dtype) * val
|
||||
|
||||
def index_fold(buf:UOp, r:UOp, idx:UOp, r2:UOp) -> UOp|None:
|
||||
if r.arg != r2.arg: return None
|
||||
base_idx = (idx-r2.src[0])//r2.src[2] # indexed from 0 to the length of the range
|
||||
return buf.index(base_idx.cast(r.dtype)*r.src[2] + r.src[0], (idx >= r2.src[0]) & (idx < r2.src[1]))
|
||||
|
||||
def reduce_rangeless(red:UOp):
|
||||
# TODO: share code with reduce_unparented
|
||||
if red.arg not in {Ops.ADD, Ops.MAX}: return None
|
||||
@@ -370,22 +352,24 @@ def reduce_rangeless(red:UOp):
|
||||
ret = red.src[0]
|
||||
if red.arg is Ops.ADD:
|
||||
for r in red.src[1:]:
|
||||
total = (r.src[1]-r.src[0]+r.src[2]-1) // r.src[2] # real count in the range
|
||||
ret = ret * total.cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
|
||||
return ret
|
||||
|
||||
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
|
||||
|
||||
pm_reduce_collapse = PatternMatcher([
|
||||
# put third arg in range
|
||||
(UPat(Ops.RANGE, src=(UPat.var(), UPat.var()), name="r"), lambda r: r.replace(src=r.src+(UOp.const(r.dtype, 1),))),
|
||||
# mul to range
|
||||
(UPat.var("x") * UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]*x, r.src[1]*x, r.src[2]*x))),
|
||||
# add to range
|
||||
(UPat.var("x") + UPat(Ops.RANGE, name="r"), lambda x,r: r.replace(src=(r.src[0]+x, r.src[1]+x, r.src[2]))),
|
||||
# fold the range with 0 in either the true or false slot
|
||||
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
|
||||
.where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_lo),
|
||||
((UPat(Ops.RANGE, src=(UPat.var("lo"), UPat.var("hi"), UPat.var("st"))) < UPat.cvar("cut")) \
|
||||
.where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True), range_fold_hi),
|
||||
# lift x+y out of reduce on lt
|
||||
((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
|
||||
# lift x*y out of reduce
|
||||
((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
|
||||
lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
|
||||
# lift x+y out of reduce on ne
|
||||
((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
|
||||
# fold the range
|
||||
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda r,cut,val: (r.src[1]-cut).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
|
||||
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda r,cut,val: (cut-r.src[0]).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
|
||||
# devectorize REDUCE
|
||||
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
|
||||
# REDUCE on ADD
|
||||
@@ -400,11 +384,10 @@ pm_reduce_collapse = PatternMatcher([
|
||||
UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
|
||||
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
|
||||
# INDEX on RANGE / gated RANGE
|
||||
(UPat.var("buf").index(UPat(Ops.RANGE, name="r"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r2"))), index_fold),
|
||||
(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
|
||||
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= r.src[0]) & (idx.cast(r.dtype) < r.src[1]))),
|
||||
# index/load. TODO: this is more aggressive than needed
|
||||
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
|
||||
# cast on RANGE (fix torch indexing)
|
||||
(UPat(Ops.RANGE, name="r").cast(name="c"), lambda r,c: r.replace(src=tuple([x.cast(c.dtype) for x in r.src]), dtype=c.dtype)),
|
||||
# AND on WHERE
|
||||
((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
|
||||
.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
|
||||
|
||||
@@ -4,7 +4,7 @@ import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod
|
||||
from tinygrad.codegen.transcendental import xpow
|
||||
|
||||
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
|
||||
@@ -75,26 +75,33 @@ def split_uop(x:UOp, sep:Ops):
|
||||
for s in x.src: yield from split_uop(s, sep)
|
||||
else: yield x
|
||||
|
||||
def fold_unrolled_divs(divs:UOp):
|
||||
def fold_unrolled_divs(divs:UOp, denominator: int, fac=1) -> UOp|None:
|
||||
# div pattern in unrolled arange
|
||||
# example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
|
||||
add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None
|
||||
for u in add_chain:
|
||||
seen_const, ans, offset = [], None, 0
|
||||
for u in split_uop(divs, Ops.ADD):
|
||||
if fac!=1:
|
||||
if u.op is not Ops.MUL or u.src[1].op is not Ops.CONST or u.src[1].arg != fac: return None
|
||||
u = u.src[0]
|
||||
if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None
|
||||
if denominator is None: denominator = u.src[1].arg
|
||||
if denominator != u.src[1].arg: return None
|
||||
if (s0:=u.src[0]).vmin < 0: return None
|
||||
# assumed CONST is the last of an ADD
|
||||
if (s0:=u.src[0]).op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
|
||||
seen_const.append(s0.src[1].arg)
|
||||
if s0.op is Ops.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
|
||||
const = s0.src[1].arg
|
||||
offset += cdiv(const, denominator)
|
||||
seen_const.append(cmod(const, denominator))
|
||||
s0 = s0.src[0]
|
||||
else: seen_const.append(0)
|
||||
if ans is None: ans = s0
|
||||
if ans is not s0: return None
|
||||
if denominator is None: return None
|
||||
if ans is None: return None
|
||||
# the first (denominator-len(seen_const)) terms may have been folded to 0 already
|
||||
for i in range(denominator-len(seen_const)):
|
||||
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
|
||||
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
|
||||
if sorted(seen_const)==list(range(denominator)):
|
||||
return fac*(ans + offset)
|
||||
return None
|
||||
|
||||
def lt_folding(x:UOp, c:int) -> UOp|None:
|
||||
p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1)
|
||||
@@ -265,7 +272,8 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("x") * UPat.cvar("c1")) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
# *** rules from symbolic ***
|
||||
# unrolled arange div folding
|
||||
(UPat(Ops.ADD, name="divs", src=[UPat(), UPat(Ops.IDIV)]), fold_unrolled_divs),
|
||||
((UPat() + UPat()//UPat.cvar("d", vec=False)).named("divs"), lambda divs,d: fold_unrolled_divs(divs, d.arg)),
|
||||
((UPat() + (UPat()//UPat.cvar("d", vec=False))*UPat.cvar("c")).named("divs"), lambda divs,d,c: fold_unrolled_divs(divs, d.arg, c.arg)),
|
||||
# generic lt folding
|
||||
(UPat.var("x", dtypes.sints)<UPat.cvar("c", vec=False), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
|
||||
Reference in New Issue
Block a user