diff --git a/test/test_schedule.py b/test/test_schedule.py index 42c6d22f09..95bdda3129 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1882,6 +1882,19 @@ class TestSwizzle(unittest.TestCase): ret = swizzle_rewrite(reswizzle) self.assertIs(ret, reswizzle) + def test_late_fusion_post_permute_simpler(self): + base = ShapeTracker.from_shape((32, 16, 1)) + start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop())) + r = start.view(start.st.expand((32, 16, 16))).r(Ops.ADD, (2,)) + add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1)) + self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1))) + to_store = add.view(add.st.permute((1, 0, 2))).contiguous() + self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1))) + self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2))) + self.assertIs(to_store.src[0].op, Ops.VIEW) + ret = graph_rewrite(to_store, view_left) + self.assertEqual(swizzle_cnt(ret), 1) + def store_val(si:ScheduleItem): return si.ast.src[0].src[2] class TestView(unittest.TestCase): def test_all_masked_out(self): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 6ff58105a2..68308d0bae 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -266,7 +266,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): src_sts = [x.st for x in self.src if x.st is not None] assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}" from tinygrad.shape.shapetracker import ShapeTracker - return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0] + return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape) @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))