base uop is always contiguous (#7907)

* base is always contiguous

* add test_late_fusion_post_permute_simpler

* Revert "swizzle tc [pr] (#7633)"

This reverts commit f02462c5cb.

* Revert "Revert "swizzle tc [pr] (#7633)""

This reverts commit a26b577d86.

* yay

* minimal diff
This commit is contained in:
qazal
2024-11-26 07:13:29 -05:00
committed by GitHub
parent ceda43ce75
commit ea57c52b99
2 changed files with 14 additions and 1 deletions

View File

@@ -1882,6 +1882,19 @@ class TestSwizzle(unittest.TestCase):
ret = swizzle_rewrite(reswizzle)
self.assertIs(ret, reswizzle)
def test_late_fusion_post_permute_simpler(self):
base = ShapeTracker.from_shape((32, 16, 1))
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
r = start.view(start.st.expand((32, 16, 16))).r(Ops.ADD, (2,))
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.view(add.st.permute((1, 0, 2))).contiguous()
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
self.assertIs(to_store.src[0].op, Ops.VIEW)
ret = graph_rewrite(to_store, view_left)
self.assertEqual(swizzle_cnt(ret), 1)
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
class TestView(unittest.TestCase):
def test_all_masked_out(self):

View File

@@ -266,7 +266,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
src_sts = [x.st for x in self.src if x.st is not None]
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))