test fixups from unmasked valid deletion [pr] (#8776)

This commit is contained in:
qazal
2025-01-28 02:23:30 -05:00
committed by GitHub
parent ed672881b0
commit aefbc2637f
3 changed files with 12 additions and 11 deletions

View File

@@ -389,12 +389,11 @@ class TestUOpMethod(unittest.TestCase):
def test_uop_variables(self):
a = UOp.variable("a", 1, 10)
uop_var = UOp.const(dtypes.int, a)
st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
ShapeTracker.from_shape((2, a)).to_uop()))
ast_vars = (st_var+uop_var).variables()
self.assertEqual(len(ast_vars), 1)
self.assertEqual(ast_vars[0], a)
uop_var = Tensor(a.bind(1))
st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1)))
_, var_vals = (uop_var+st_var).schedule_with_vars()
self.assertEqual(len(var_vals), 1)
self.assertEqual(list(var_vals)[0], a)
def test_const_factor(self):
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))

View File

@@ -88,11 +88,12 @@ class TestVerifyAST(unittest.TestCase):
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
def test_flat_const_always_valid(self):
def test_const_view_always_valid(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
helper_test_verify_ast(st)
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CLANG"),), ShapeTracker.from_shape(())),))
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float))
# lowerer asserts because it does not remove ShapeTracker on CONST(VIEW(DEVICE))
with self.assertRaises(AssertionError): helper_test_verify_ast(st)
if __name__ == '__main__':
unittest.main()

View File

@@ -697,7 +697,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) ->
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
# movementops are pushed to VIEW
elif uop.op is Ops.VIEW:
assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
# NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
st = uop.arg
# everything else inherits shape
else: