fix permute stacking

This commit is contained in:
George Hotz
2022-07-17 10:24:57 -07:00
parent c28a99087b
commit 77806e0d64

View File

@@ -239,7 +239,7 @@ class LazyBuffer:
if op == MovementOps.RESHAPE and x.realized is None and x.op.op == MovementOps.RESHAPE: return x.op.src[0].movement_op(op, arg)
# two permutes in a row is one permute
if op == MovementOps.PERMUTE and x.realized is None and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(arg[i] for i in x.op.arg))
if op == MovementOps.PERMUTE and x.realized is None and x.op.op == MovementOps.PERMUTE: return x.op.src[0].movement_op(op, tuple(x.op.arg[i] for i in arg))
# TODO: SHUFFLE_SLICE_OPS is okay if it's a shrink
if (SHUFFLE_MOVEMENT_OPS or (SHUFFLE_RESHAPE_OPS and op == MovementOps.RESHAPE)) and x.optype == BinaryOps and x.realized is None and (SHUFFLE_SLICE_OPS or op != MovementOps.SLICE):