different valid order (#7589)

in simplify_valid, we start with valids that are in others' parent so the others is more likely to be simplified
This commit is contained in:
chenyu
2024-11-07 20:27:56 -05:00
committed by GitHub
parent dc7b0e2bb7
commit a1dfd288bb
2 changed files with 28 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
import unittest
import unittest, itertools
from typing import Tuple
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
@@ -79,6 +79,26 @@ class TestValidIdxSimplification(unittest.TestCase):
valid = alu0.lt(57) & alu0.ge(1)
self.assertIsNone(simplify_valid(valid))
def test_valid_order_matters1(self):
ridx0 = Range(0, 2)
v0 = ridx0.lt(1)
v1 = ((ridx0*5+1)%6).lt(5)
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
def test_valid_order_matters2(self):
gidx0 = Special("gidx0", 13)
gidx1 = Special("gidx1", 13)
ridx0 = Range(0, 4)
alu0 = (gidx1+(ridx0*13))
v0 = ((gidx0+11)%14).lt(11)
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
v2 = gidx0.lt(3)
v3 = alu0.lt(42)
for v in itertools.permutations([v0,v1,v2,v3]):
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]

View File

@@ -992,10 +992,16 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
return uop
def _valid_priority(v: UOp, valids:List[UOp]):
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
except ValueError: return 0
def simplify_valid(valid:UOp) -> Optional[UOp]:
ret:List[UOp] = []
something_changed = False
for stmt in split_uop(valid, BinaryOps.AND):
valids = list(split_uop(valid, Ops.AND))
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
if ret[-1] is not stmt: something_changed = True
return functools.reduce(operator.and_, ret) if something_changed else None