From 8672a9db3f063de61c729cc08072a494eea8a0a1 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 31 Jul 2024 12:59:38 -0700 Subject: [PATCH] add test to validate lazyops dims (#5845) --- test/test_linearizer_failures.py | 2 +- tinygrad/ops.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 9b11292efc..4d2a1ee068 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -390,7 +390,7 @@ class TestLinearizerFailures(unittest.TestCase): for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) self.assertLessEqual(len(ifs[0].src[0].sparents), 16) - @unittest.expectedFailure + @unittest.skip("this is an invalid lazyop") def test_failure_45(self): ast = LazyOp(MetaOps.KERNEL, arg=None, src=( LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), src=( diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 2526da552e..41f7e77402 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -165,4 +165,6 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]: assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}" assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}" assert_valid(out, out.arg.st) + shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])] + assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}" return sts