diff --git a/test/test_schedule.py b/test/test_schedule.py index be8047d2cb..2eb57885de 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1726,6 +1726,14 @@ class TestIndexing(unittest.TestCase): self.check_schedule(out, 1) np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6) + def test_arange_index_shrink(self): + Tensor.manual_seed(0) + with Context(TRACK_MATCH_STATS=0): + x = Tensor.randn(11).realize() + a = Tensor.arange(22) + out = (x + a[:11]).sum() + self.check_schedule(out, 1) + @unittest.skip("TOOD: FUSE_ARANGE overrules Tensor.arange().contiguous()") def test_arange_index_contiguous(self): Tensor.manual_seed(0) diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 037349848a..54d0a26c69 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -308,7 +308,7 @@ view_left = merge_views+PatternMatcher([ def apply_swizzle(u:UOp) -> UOp: return graph_rewrite(u, view_left, name="Sub View Left") def swizzle_reduceop(r:UOp, src:UOp, view:UOp, fuse=False): - if (st:=unwrap(view.st)).contiguous: return None + if (st:=unwrap(view.st)).contiguous and st.size == r.size: return None input_st = ShapeTracker.from_shape(src.shape) tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])