diff --git a/test/test_schedule.py b/test/test_schedule.py
index fd32f8c543..5ac76daf16 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -8,7 +8,7 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
 from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
@@ -1270,9 +1270,10 @@ class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Tensor, cnt:int):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
       s = xt.schedule()
-      kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
+      kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]
+      for si in kernels: verify_lazyop(si.ast)
       run_schedule(s)
-    if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
+    if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
 
   def test_simple_indexing(self):
     X = Tensor.randn(10, 10).realize()
@@ -1281,11 +1282,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
 
-  @unittest.expectedFailure
   def test_simple_indexing_alt(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[[1, 2], [1, 2]]
-    self.check_schedule(xt, 5)
+    self.check_schedule(xt, 3)
     np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
 
   @unittest.expectedFailure
@@ -1302,11 +1302,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 6)
     np.testing.assert_equal(xt.numpy(), 6)
 
-  @unittest.expectedFailure
   def test_advanced_simple_indexing_combined(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[1:2, [1, 2]]
-    self.check_schedule(xt, 4)
+    self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
 
   def test_push_through_reshape(self):
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index a0dd779541..b034ad793c 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -131,10 +131,18 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
-  # unify the kernel dims
+  # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # pad all reduceops to the max of each dimension
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+  for i,dims in enumerate(shape_dims):
+    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+    for (r,view),(input_st,axis) in reduce_info.items():
+      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+        reduce_info[(r, view)] = (input_st, axis)
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
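For reference, a minimal standalone sketch of the new dimension-padding step above. This is not tinygrad code: plain shape tuples stand in for ShapeTrackers, and `dedup`/`pad_to_max_dims` are hypothetical helpers. It only illustrates the idea that each axis gets padded up to the maximum size seen across the reduce inputs, while axes that already agree (or only differ by a broadcasted 1) are left untouched.

```python
from typing import List, Tuple

def dedup(xs):  # order-preserving unique, stand-in for tinygrad.helpers.dedup
  return list(dict.fromkeys(xs))

def pad_to_max_dims(shapes: List[Tuple[int, ...]]) -> List[Tuple[Tuple[int, int], ...]]:
  """Return the (before, after) padding per axis that unifies each dimension to its max."""
  pads = [tuple((0, 0) for _ in s) for s in shapes]
  # distinct sizes seen for each axis across all reduce inputs (assumes equal rank)
  shape_dims = [sorted(dedup(dims)) for dims in zip(*shapes)]
  for i, dims in enumerate(shape_dims):
    # skip axes that already agree, or only differ by a broadcasted 1
    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
    for j, s in enumerate(shapes):
      if (dim := s[i]) > 1 and dim != max(dims):
        # pad this axis at the end up to the max size
        pads[j] = pads[j][:i] + ((0, max(dims) - dim),) + pads[j][i+1:]
  return pads

if __name__ == "__main__":
  # two reduce inputs with mismatched second dims: (4, 5) gets padded up to (4, 7)
  print(pad_to_max_dims([(4, 5), (4, 7)]))  # [((0, 0), (0, 2)), ((0, 0), (0, 0))]
```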