different valid order (#7589)

in simplify_valid, we start with valids that are in others' parent so the others is more likely to be simplified
This commit is contained in:
chenyu
2024-11-07 20:27:56 -05:00
committed by GitHub
parent dc7b0e2bb7
commit a1dfd288bb
2 changed files with 28 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
import unittest
import unittest, itertools
from typing import Tuple
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
@@ -79,6 +79,26 @@ class TestValidIdxSimplification(unittest.TestCase):
valid = alu0.lt(57) & alu0.ge(1)
self.assertIsNone(simplify_valid(valid))
def test_valid_order_matters1(self):
ridx0 = Range(0, 2)
v0 = ridx0.lt(1)
v1 = ((ridx0*5+1)%6).lt(5)
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
def test_valid_order_matters2(self):
gidx0 = Special("gidx0", 13)
gidx1 = Special("gidx1", 13)
ridx0 = Range(0, 4)
alu0 = (gidx1+(ridx0*13))
v0 = ((gidx0+11)%14).lt(11)
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
v2 = gidx0.lt(3)
v3 = alu0.lt(42)
for v in itertools.permutations([v0,v1,v2,v3]):
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]