From 4074f523179331d3a99dd55fa4ef14946563f246 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 7 Dec 2024 12:37:47 +0200 Subject: [PATCH] VALID early folding (#8100) * fold valid * :) * fix test_verify_ast * keep symbolic working --- test/test_schedule.py | 4 ++-- test/unit/test_verify_ast.py | 5 +++-- tinygrad/engine/schedule.py | 3 +++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 1f3abccb74..2cef884c7d 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1905,7 +1905,7 @@ class TestView(unittest.TestCase): b = a.pad(((0, 10), None))[10:] sched = check_schedule(b.contiguous(), 1) # TODO: this VALID can clean up, where do we need st? - self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape)) + self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0)) run_schedule(sched) np.testing.assert_equal(b.numpy(), 0) @@ -1916,7 +1916,7 @@ class TestView(unittest.TestCase): assert b.shape == (10, 10) sched = check_schedule(b.contiguous(), 1) self.assertEqual(sched[-1].ast.full_shape, (10, 10)) - self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape)) + self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0)) run_schedule(sched) np.testing.assert_equal(b.numpy(), 0) diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index d38967a7c8..d08319a747 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -78,8 +78,9 @@ class TestVerifyAST(unittest.TestCase): uop_sts = verify_ast(a.schedule()[-1].ast) store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0] self.assertEqual(store_st, ShapeTracker.from_shape((4, 4))) - const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0] - self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) + with self.assertRaises(Exception, msg="unmasked valid folds"): + const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0] + self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) def test_assert_swizzle(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a7054c5a8a..c35f0a299e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -178,6 +178,9 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), to_si = PatternMatcher([ (UPat(Ops.VIEW, name="x"), _append_st_vars), (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))), + # unmasked VALID is just CONST + (UPat(Ops.VALID, name="valid").where(UPat.cvar("x"), UPat()), + lambda ctx,valid,x: x if all_int(valid.shape) and all(v.mask is None for v in valid.st.views) else None), # don't need contiguous or assign anymore (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),