rewrite some strideds into reshapes

This commit is contained in:
George Hotz
2022-10-30 16:31:27 -07:00
parent 8c849e637c
commit b7a115e5e5

View File

@@ -178,7 +178,7 @@ class LazyBuffer:
local_st = ShapeTracker(self.shape).movement_op(op, arg)
# instant nops
if local_st.contiguous and self.shape == local_st.shape:
if local_st.contiguous and self.shape == local_st.shape and op != MovementOps.STRIDED:
return self
# two ops in a row are one op; merge them if unresolved
@@ -195,6 +195,11 @@ class LazyBuffer:
if op == MovementOps.PERMUTE and local_st.contiguous:
return self.movement_op(MovementOps.RESHAPE, tuple(self.shape[i] for i in arg))
# some strideds are actually just reshapes
# NOTE: due to how strided works, we also have to check that the parent is contiguous
if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous:
return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg))
# if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
def replace_with_movement_op(y:Union[LazyOp, LazyBuffer]) -> LazyBuffer: