mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test fixups from unmasked valid deletion [pr] (#8776)
This commit is contained in:
@@ -389,12 +389,11 @@ class TestUOpMethod(unittest.TestCase):
|
||||
|
||||
def test_uop_variables(self):
|
||||
a = UOp.variable("a", 1, 10)
|
||||
uop_var = UOp.const(dtypes.int, a)
|
||||
st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
|
||||
ShapeTracker.from_shape((2, a)).to_uop()))
|
||||
ast_vars = (st_var+uop_var).variables()
|
||||
self.assertEqual(len(ast_vars), 1)
|
||||
self.assertEqual(ast_vars[0], a)
|
||||
uop_var = Tensor(a.bind(1))
|
||||
st_var = Tensor.empty((2, 1)).reshape((2, a.bind(1)))
|
||||
_, var_vals = (uop_var+st_var).schedule_with_vars()
|
||||
self.assertEqual(len(var_vals), 1)
|
||||
self.assertEqual(list(var_vals)[0], a)
|
||||
|
||||
def test_const_factor(self):
|
||||
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))
|
||||
|
||||
@@ -88,11 +88,12 @@ class TestVerifyAST(unittest.TestCase):
|
||||
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
|
||||
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
|
||||
|
||||
def test_flat_const_always_valid(self):
|
||||
def test_const_view_always_valid(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
|
||||
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
|
||||
helper_test_verify_ast(st)
|
||||
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CLANG"),), ShapeTracker.from_shape(())),))
|
||||
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float))
|
||||
# lowerer asserts because it does not remove ShapeTracker on CONST(VIEW(DEVICE))
|
||||
with self.assertRaises(AssertionError): helper_test_verify_ast(st)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -697,7 +697,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) ->
|
||||
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
|
||||
# movementops are pushed to VIEW
|
||||
elif uop.op is Ops.VIEW:
|
||||
assert len(uop.src) == 0, f"can't swizzle in kernel yet {uop}"
|
||||
# NOTE: we disallow VIEW in the middle of the AST, if it has a DEVICE source it's fine
|
||||
assert len(uop.src) == 0 or uop.src[0].op is Ops.DEVICE, f"can't swizzle in kernel yet {uop}"
|
||||
st = uop.arg
|
||||
# everything else inherits shape
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user