mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
base uop is always contiguous (#7907)
* base is always contiguous * add test_late_fusion_post_permute_simpler * Revert "swizzle tc [pr] (#7633)" This reverts commit f02462c5cb. * Revert "Revert "swizzle tc [pr] (#7633)"" This reverts commit a26b577d86. * yay * minimal diff
This commit is contained in:
@@ -1882,6 +1882,19 @@ class TestSwizzle(unittest.TestCase):
|
||||
ret = swizzle_rewrite(reswizzle)
|
||||
self.assertIs(ret, reswizzle)
|
||||
|
||||
def test_late_fusion_post_permute_simpler(self):
|
||||
base = ShapeTracker.from_shape((32, 16, 1))
|
||||
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
|
||||
r = start.view(start.st.expand((32, 16, 16))).r(Ops.ADD, (2,))
|
||||
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
|
||||
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
|
||||
to_store = add.view(add.st.permute((1, 0, 2))).contiguous()
|
||||
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
|
||||
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
|
||||
self.assertIs(to_store.src[0].op, Ops.VIEW)
|
||||
ret = graph_rewrite(to_store, view_left)
|
||||
self.assertEqual(swizzle_cnt(ret), 1)
|
||||
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
|
||||
class TestView(unittest.TestCase):
|
||||
def test_all_masked_out(self):
|
||||
|
||||
@@ -266,7 +266,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
src_sts = [x.st for x in self.src if x.st is not None]
|
||||
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
|
||||
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg) if self.op is Ops.REDUCE_AXIS else src_sts[0].shape)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
|
||||
Reference in New Issue
Block a user