mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
VALID early folding (#8100)
* fold valid * :) * fix test_verify_ast * keep symbolic working
This commit is contained in:
@@ -1905,7 +1905,7 @@ class TestView(unittest.TestCase):
|
||||
b = a.pad(((0, 10), None))[10:]
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
# TODO: this VALID can clean up, where do we need st?
|
||||
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
|
||||
self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0))
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
@@ -1916,7 +1916,7 @@ class TestView(unittest.TestCase):
|
||||
assert b.shape == (10, 10)
|
||||
sched = check_schedule(b.contiguous(), 1)
|
||||
self.assertEqual(sched[-1].ast.full_shape, (10, 10))
|
||||
self.assertIs(store_val(sched[-1]), UOp.const_with_shape(b.dtype, 0, b.lazydata.st.shape))
|
||||
self.assertIs(store_val(sched[-1]), UOp.const(b.dtype, 0))
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(b.numpy(), 0)
|
||||
|
||||
|
||||
@@ -78,8 +78,9 @@ class TestVerifyAST(unittest.TestCase):
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
with self.assertRaises(Exception, msg="unmasked valid folds"):
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
def test_assert_swizzle(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
|
||||
@@ -178,6 +178,9 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
|
||||
to_si = PatternMatcher([
|
||||
(UPat(Ops.VIEW, name="x"), _append_st_vars),
|
||||
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
|
||||
# unmasked VALID is just CONST
|
||||
(UPat(Ops.VALID, name="valid").where(UPat.cvar("x"), UPat()),
|
||||
lambda ctx,valid,x: x if all_int(valid.shape) and all(v.mask is None for v in valid.st.views) else None),
|
||||
# don't need contiguous or assign anymore
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),
|
||||
|
||||
Reference in New Issue
Block a user