From a1dfd288bb434fc684872d7cde8bbee11e0cd524 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 7 Nov 2024 20:27:56 -0500 Subject: [PATCH] different valid order (#7589) in simplify_valid, we start with valids that are in others' parents so the others are more likely to be simplified --- test/unit/test_simplify_valid_idx.py | 22 +++++++++++++++++++++- tinygrad/ops.py | 8 +++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 35b7881648..3630df4d3b 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -1,4 +1,4 @@ -import unittest +import unittest, itertools from typing import Tuple from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing @@ -79,6 +79,26 @@ class TestValidIdxSimplification(unittest.TestCase): valid = alu0.lt(57) & alu0.ge(1) self.assertIsNone(simplify_valid(valid)) + def test_valid_order_matters1(self): + ridx0 = Range(0, 2) + v0 = ridx0.lt(1) + v1 = ((ridx0*5+1)%6).lt(5) + self.assertEqual(simplify_valid(v0&v1).render(), "(ridx0<1)") + self.assertEqual(simplify_valid(v1&v0).render(), "(ridx0<1)") + + def test_valid_order_matters2(self): + gidx0 = Special("gidx0", 13) + gidx1 = Special("gidx1", 13) + ridx0 = Range(0, 4) + alu0 = (gidx1+(ridx0*13)) + v0 = ((gidx0+11)%14).lt(11) + v1 = ((alu0+((gidx0+39)//42))%14).lt(11) + v2 = gidx0.lt(3) + v3 = alu0.lt(42) + + for v in itertools.permutations([v0,v1,v2,v3]): + self.assertEqual(simplify_valid(v[0]&v[1]&v[2]&v[3]).render(), "False") + class TestImageSimplification(unittest.TestCase): def check(self, load, svalid, sidx0, sidx1): load = full_graph_rewrite(load.sink()).src[0] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cfde0f88d4..5e92d30df4 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -992,10 +992,16 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: return uop +def _valid_priority(v: UOp, valids:List[UOp]): + # we want valid 
that's in other valids' parents to be first, so it's more likely the other valids get simplified + try: return sum(-1 if parse_valid(v)[0] in other.parents else 0 for other in valids) + except ValueError: return 0 + def simplify_valid(valid:UOp) -> Optional[UOp]: ret:List[UOp] = [] something_changed = False - for stmt in split_uop(valid, BinaryOps.AND): + valids = list(split_uop(valid, Ops.AND)) + for stmt in sorted(valids, key=lambda v: _valid_priority(v, valids)): ret.append(newstmt if ret and (newstmt:=uop_given_valid(functools.reduce(operator.and_, ret), stmt)) is not None else stmt) if ret[-1] is not stmt: something_changed = True return functools.reduce(operator.and_, ret) if something_changed else None