mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
different valid order (#7589)
in simplify_valid, we start with valids that are in others' parent so the others is more likely to be simplified
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import unittest
|
||||
import unittest, itertools
|
||||
from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||
@@ -79,6 +79,26 @@ class TestValidIdxSimplification(unittest.TestCase):
|
||||
valid = alu0.lt(57) & alu0.ge(1)
|
||||
self.assertIsNone(simplify_valid(valid))
|
||||
|
||||
def test_valid_order_matters1(self):
|
||||
ridx0 = Range(0, 2)
|
||||
v0 = ridx0.lt(1)
|
||||
v1 = ((ridx0*5+1)%6).lt(5)
|
||||
self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)")
|
||||
self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)")
|
||||
|
||||
def test_valid_order_matters2(self):
|
||||
gidx0 = Special("gidx0", 13)
|
||||
gidx1 = Special("gidx1", 13)
|
||||
ridx0 = Range(0, 4)
|
||||
alu0 = (gidx1+(ridx0*13))
|
||||
v0 = ((gidx0+11)%14).lt(11)
|
||||
v1 = ((alu0+((gidx0+39)//42))%14).lt(11)
|
||||
v2 = gidx0.lt(3)
|
||||
v3 = alu0.lt(42)
|
||||
|
||||
for v in itertools.permutations([v0,v1,v2,v3]):
|
||||
self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def check(self, load, svalid, sidx0, sidx1):
|
||||
load = full_graph_rewrite(load.sink()).src[0]
|
||||
|
||||
@@ -992,10 +992,16 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
|
||||
|
||||
return uop
|
||||
|
||||
def _valid_priority(v: UOp, valids:List[UOp]):
|
||||
# we want valid that's in other valids' parents to be first, so it's more likely the other valids get simplified
|
||||
try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids)
|
||||
except ValueError: return 0
|
||||
|
||||
def simplify_valid(valid:UOp) -> Optional[UOp]:
|
||||
ret:List[UOp] = []
|
||||
something_changed = False
|
||||
for stmt in split_uop(valid, BinaryOps.AND):
|
||||
valids = list(split_uop(valid, Ops.AND))
|
||||
for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)):
|
||||
ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt)
|
||||
if ret[-1] is not stmt: something_changed = True
|
||||
return functools.reduce(operator.and_, ret) if something_changed else None
|
||||
|
||||
Reference in New Issue
Block a user