add UOps.VALID (#6387)

* uops valid

* broke full_shape

* fixup that st (hardcoded asts still red)

* fixup DEFINE_VAR

debug

more debug

* start moving stuff to ast_const

* move test_linearizer

* move test_linearizer_failures to ast_const

* fixup test_schedule

* small diff change

* regenerate dataset

* fixup test_multitensor

* regen dataset try 2

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2024-09-09 16:58:43 +08:00
committed by GitHub
parent e1d61b048b
commit 8186e4e7d6
9 changed files with 23 additions and 20 deletions

View File

@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
if __name__ == '__main__':