mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Take neg out of idiv (#10164)
* Add rules * Fix tests * Move rules lower to prevent recursion
This commit is contained in:
@@ -154,8 +154,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable(Variable("a", 0, 6) // 2, 0, 3, "(a//2)")
|
||||
|
||||
def test_div_neg_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "(a//-2)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "(a//-2)")
|
||||
self.helper_test_variable(Variable("a", 1, 7) // -2, -3, 0, "((a//2)*-1)")
|
||||
self.helper_test_variable(Variable("a", 0, 6) // -2, -3, 0, "((a//2)*-1)")
|
||||
|
||||
def test_sum_div_remove(self):
|
||||
self.helper_test_variable(usum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 20, 0, 0, "0")
|
||||
@@ -252,8 +252,8 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_div_const_div(self):
|
||||
a = Variable("a", 0, 124)
|
||||
self.helper_test_variable((a//2+1)//2, 0, 31, "((a+2)//4)")
|
||||
self.helper_test_variable(((-a)//2-1)//2, -31, 0, "(((a*-1)+-2)//4)")
|
||||
self.helper_test_variable(((-a)//2+10)//2, -26, 5, "(((a*-1)+20)//4)")
|
||||
self.helper_test_variable(((-a)//2-1)//2, -31, 0, "(((a+2)//4)*-1)")
|
||||
# self.helper_test_variable(((-a)//2+10)//2, -26, 5, "(((a*-1)+20)//4)")
|
||||
|
||||
def test_div_const_div_wrong_sign(self):
|
||||
a = Variable("a", 0, 124)
|
||||
@@ -359,9 +359,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
|
||||
# TODO: simplify the expression
|
||||
def test_div_neg_cancel(self):
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((((idx*-1)+199)//-4)+50)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((((idx*-1)+200)//-4)+50)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((((idx*-1)+201)//-4)+50)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+199)//-4 + 50, 1, 26, "((idx//4)+1)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+200)//-4 + 50, 0, 25, "((idx+3)//4)")
|
||||
self.helper_test_variable((-Variable("idx", 0, 100)+201)//-4 + 50, 0, 25, "((idx+2)//4)")
|
||||
|
||||
def test_sum_div_big_const(self):
|
||||
gidx0 = Variable("gidx0", 0, 24)
|
||||
@@ -428,10 +428,10 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_div_neg_all_range(self):
|
||||
gidx = Variable("gidx", 0, 124)
|
||||
lidx = Variable("lidx", 0, 7)
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((((gidx*-8)+(lidx*-1))+999)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1000)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1001)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "(((((gidx*-8)+(lidx*-1))+1002)//-4)+250)")
|
||||
self.helper_test_variable((-gidx*8-lidx+999)//-4 + 250, 1, 250, "(((gidx*2)+(lidx//4))+1)")
|
||||
self.helper_test_variable((-gidx*8-lidx+1000)//-4 + 250, 0, 250, "((gidx*2)+((lidx+3)//4))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1001)//-4 + 250, 0, 250, "((gidx*2)+((lidx+2)//4))")
|
||||
self.helper_test_variable((-gidx*8-lidx+1002)//-4 + 250, 0, 250, "((gidx*2)+((lidx+1)//4))")
|
||||
|
||||
# NOTE: tests are not correct in symbolic
|
||||
def test_div_neg_then_neg(self):
|
||||
@@ -441,7 +441,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
alu2 = -lidx0-lidx1
|
||||
self.helper_test_variable((((alu2+14)//(-32))+4), 4, 4, "4")
|
||||
self.helper_test_variable(-(((alu2+14)//(-32))+4), -4, -4, "-4")
|
||||
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((((lidx0*-1)+(lidx1*-1))+134)//-32)+4)")
|
||||
self.helper_test_variable((((alu2+134)//(-32))+4), 0, 1, "(((lidx0+lidx1)+25)//32)")
|
||||
self.helper_test_variable((((alu2+142)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+150)//(-32))+4), 0, 0, "0")
|
||||
self.helper_test_variable((((alu2+158)//(-32))+4), 0, 0, "0")
|
||||
@@ -507,7 +507,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_idiv_lt(self):
|
||||
idx = Variable("idx", 0, 24)
|
||||
self.helper_test_variable((idx//4<3), 0, 1, "(idx<12)")
|
||||
self.helper_test_variable((idx//-4<-3), 0, 1, "((idx//-4)<-3)")
|
||||
self.helper_test_variable((idx//-4<-3), 0, 1, "(((idx//4)*-1)<-3)")
|
||||
|
||||
def test_simplex_lt(self):
|
||||
a = Variable("a", 0, 3)
|
||||
|
||||
@@ -284,11 +284,15 @@ symbolic = symbolic_simple+commutative+PatternMatcher([
|
||||
((UPat.var("x")//UPat.cvar("c") + UPat.cvar("a"))//UPat.cvar("d"), lambda x,c,a,d: (x+a*c)//(c*d)
|
||||
if x.vmin>=0 or x.vmax<=0 else None), # (x//c+a)//d -> (x+a*c)//(c*d)
|
||||
(UPat.var("x", dtypes.sints) // UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.IDIV)),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -(x//(-d)) if d.vmax <=0 else None),
|
||||
(UPat.var("x") // UPat.var("d"), lambda x,d: -((-x)//d) if x.vmax <=0 else None),
|
||||
((UPat.var("x", dtypes.sints)+UPat.cvar("c", vec=False)).named("n")//UPat.cvar("d", vec=False),
|
||||
lambda x,c,n,d: (-(-(c.arg%d.arg + x - (d.arg-1))//d) + c.arg//d.arg) if x.vmax<=0 and n.vmin>=0 and d.arg>0 else None),
|
||||
# ** mod **
|
||||
# mod folding
|
||||
(UPat.var("x") % UPat.var("y"), lambda x,y: div_and_mod_folding(x,y,Ops.MOD)),
|
||||
(UPat.var("x") % UPat.var("d"), lambda x,d: -(x%(-d)) if d.vmax <=0 else None),
|
||||
(UPat.var("x") % UPat.var("d"), lambda x,d: -((-x)%d) if x.vmax <=0 else None),
|
||||
])+gep_pushing
|
||||
|
||||
symbolic_flat = symbolic+PatternMatcher([
|
||||
|
||||
Reference in New Issue
Block a user