mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Improved div folding (#7996)
* First version of div_mod folding together * Working version with old div folding behaviour * Test is fixed * Fix linting * Happy mypy --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -205,6 +205,13 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((7+9*Variable("x",0,2)+9*Variable("y",0,2)+Variable("z",0,2))%10, 3, 9, "(((z+(x*-1))+(y*-1))+7)")
|
||||
self.helper_test_variable((10+12*Variable("x",0,2)+Variable("y", 0, 4)%3)%13, 8, 12, "(((x*-1)+(y%3))+10)")
|
||||
|
||||
def test_div_congruence(self):
|
||||
self.helper_test_variable((3+3*Variable("a",0,3))//4, 0, 3, "a")
|
||||
self.helper_test_variable((18+17*Variable("a",0,2)+17)//18, 1, 3, "(a+1)")
|
||||
|
||||
def test_div_congruence_multiple_vars(self):
|
||||
self.helper_test_variable((9+(9+10)*Variable("x",0,3)+(8+10)*Variable("y",0,2))//10, 0, 10, "((x*2)+(y*2))")
|
||||
|
||||
def test_mod_binary_expression(self):
|
||||
self.helper_test_variable((3+Variable("a",0,1))%4, 0, 3, "((a*-3)+3)")
|
||||
self.helper_test_variable((3+Variable("a",4,5))%4, 0, 3, "((a*-3)+15)")
|
||||
@@ -426,8 +433,7 @@ class TestSymbolic(unittest.TestCase):
|
||||
def test_div_mod_recombine_folded_mod(self):
|
||||
a = Variable("a", 0, 2)
|
||||
b = Variable("b", 0, 100)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
|
||||
self.helper_test_variable((31 * a + 1) % 30 + ((31 * a + 1) // 30) * 30, 1, 63, "((a*31)+1)")
|
||||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_variable((31 * b + 1) % 18 + ((31 * b + 1) // 18) * 18, 1, 3101, "((b*31)+1)")
|
||||
|
||||
@@ -450,11 +456,9 @@ class TestSymbolic(unittest.TestCase):
|
||||
unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
|
||||
self.helper_test_variable(unrolled_div, 0, 2, "gidx")
|
||||
|
||||
# TODO: fix this, it has only one term and is no longer an add chain
|
||||
with self.assertRaises(AssertionError):
|
||||
gidx = Variable("gidx", 0, 1)
|
||||
unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
|
||||
self.helper_test_variable(unrolled_div, 0, 1, "gidx")
|
||||
gidx = Variable("gidx", 0, 1)
|
||||
unrolled_div = (gidx)//4+(gidx+2)//4+(gidx+3)//4+(gidx+1)//4
|
||||
self.helper_test_variable(unrolled_div, 0, 1, "gidx")
|
||||
|
||||
def test_arange_unrolled2(self):
|
||||
gidx = Variable("gidx", 0, 2559)
|
||||
|
||||
Reference in New Issue
Block a user