From b7a115e5e5bc626b339eb3bec6d3375db5e76f88 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 30 Oct 2022 16:31:27 -0700 Subject: [PATCH] rewrite some strideds into reshapes --- tinygrad/lazy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 315c5f678d..3f904c3060 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -178,7 +178,7 @@ class LazyBuffer: local_st = ShapeTracker(self.shape).movement_op(op, arg) # instant nops - if local_st.contiguous and self.shape == local_st.shape: + if local_st.contiguous and self.shape == local_st.shape and op != MovementOps.STRIDED: return self # two ops in a row is one op. merge them if unresolved @@ -195,6 +195,11 @@ class LazyBuffer: if op == MovementOps.PERMUTE and local_st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg)) + # some strideds are actually just reshapes + # NOTE: due to how strided works, we have to check the parent to be contiguous also + if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous: + return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg)) + # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]: def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer: