mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
different valid order (#7589)
in simplify_valid, we start with valids that are in others' parent so the others is more likely to be simplified
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import unittest
|
||||
import unittest, itertools
|
||||
from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||
@@ -79,6 +79,26 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = alu0.lt(57) & alu0.ge(1)
|
||||
self.assertIsNone(simplify_valid(valid))
|
||||
|
||||
def test_valid_order_matters1(self):
|
||||
ridx0 = Range(0, 2)
|
||||
v0 = ridx0.lt(1)
|
||||
v1 = ((ridx0*5+1)%6).lt(5)
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
|
||||
|
||||
def test_valid_order_matters2(self):
|
||||
gidx0 = Special("gidx0", 13)
|
||||
gidx1 = Special("gidx1", 13)
|
||||
ridx0 = Range(0, 4)
|
||||
alu0 = (gidx1+(ridx0*13))
|
||||
v0 = ((gidx0+11)%14).lt(11)
|
||||
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
|
||||
v2 = gidx0.lt(3)
|
||||
v3 = alu0.lt(42)
|
||||
|
||||
for v in itertools.permutations([v0,v1,v2,v3]):
|
||||
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
|
||||
Reference in New Issue
Block a user