where fold try 2 (#3748)

* where fold try 2

* assign fold

* test_where_fold works

* add gated store support to ops_python

---------

Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
George Hotz
2024-03-15 07:46:26 -07:00
committed by GitHub
parent 6b8c66e04f
commit ca19eb3e82
4 changed files with 39 additions and 7 deletions

View File

@@ -198,6 +198,25 @@ class TestLinearizer(unittest.TestCase):
lin = Linearizer(*sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(a+m)
a.realize()
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
def test_where_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
a.assign(b.where(2, a))
sched = create_schedule([a.lazydata])
assert len(sched) == 1
lin = Linearizer(*sched[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
def test_simplify_uop(self):
def helper_test_simplify(uop, dtype, vin, arg=None):
ast = LazyOp(BufferOps.CONST, (),