VALID early folding (#8100)

* fold valid

* :)

* fix test_verify_ast

* keep symbolic working
This commit is contained in:
qazal
2024-12-07 12:37:47 +02:00
committed by GitHub
parent 07b6d5cf63
commit 4074f52317
3 changed files with 8 additions and 4 deletions

View File

@@ -1905,7 +1905,7 @@ class TestView(unittest.TestCase):
b = a.pad(((0, 10), None))[10:]
sched = check_schedule(b.contiguous(), 1)
# TODO: this VALID can clean up, where do we need st?
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0))
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)
@@ -1916,7 +1916,7 @@ class TestView(unittest.TestCase):
assert b.shape == (10, 10)
sched = check_schedule(b.contiguous(), 1)
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0))
run_schedule(sched)
np.testing.assert_equal(b.numpy(), 0)

View File

@@ -78,8 +78,9 @@ class TestVerifyAST(unittest.TestCase):
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
with self.assertRaises(Exception, msg="unmasked valid folds"):
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
def test_assert_swizzle(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)

View File

@@ -178,6 +178,9 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
to_si = PatternMatcher([
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
# unmasked VALID is just CONST
(UPat(Ops.VALID, name="valid").where(UPat.cvar("x"), UPat()),
lambda ctx,valid,x: x if all_int(valid.shape) and all(v.mask is None for v in valid.st.views) else None),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),