diff --git a/test/test_uops.py b/test/test_uops.py index b7392e46aa..24755f2e8c 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -389,12 +389,11 @@ class TestUOpMethod(unittest.TestCase): def test_uop_variables(self): a = UOp.variable("a", 1, 10) - uop_var = UOp.const(dtypes.int, a) - st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), - ShapeTracker.from_shape((2, a)).to_uop())) - ast_vars = (st_var+uop_var).variables() - self.assertEqual(len(ast_vars), 1) - self.assertEqual(ast_vars[0], a) + uop_var = Tensor(a.bind(1)) + st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1))) + _, var_vals = (uop_var+st_var).schedule_with_vars() + self.assertEqual(len(var_vals), 1) + self.assertEqual(list(var_vals)[0], a) def test_const_factor(self): gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8)) diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index c078640aab..2dd103e975 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -88,11 +88,12 @@ class TestVerifyAST(unittest.TestCase): st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a) with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st) - def test_flat_const_always_valid(self): + def test_const_view_always_valid(self): buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) - a = UOp.const(dtypes.int, 0).cast(dtypes.float) - st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a) - helper_test_verify_ast(st) + a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CLANG"),), ShapeTracker.from_shape(())),)) + st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float)) + # lowerer asserts because it does not remove ShapeTracker on CONST(VIEW(DEVICE)) + with self.assertRaises(AssertionError): helper_test_verify_ast(st) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 048bbf2e65..7ce68968c0 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -697,7 +697,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg)) # movementops are pushed to VIEW elif uop.op is Ops.VIEW: - assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}" + # NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine + assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}" st = uop.arg # everything else inherits shape else: