diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index f81e3b1383..71975c5c2c 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -250,7 +250,6 @@ class LazyBuffer: x = self if IMAGE >= 1: - x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)) w = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)) added_output_channels = 0 @@ -258,8 +257,10 @@ class LazyBuffer: if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0): to_add = 4 - (C.cin % 4) w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))]) + x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix)) x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))]) C = C._replace(cin = C.cin + to_add) + x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix)) # hack for non multiples of 4 on C.rcout if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0): @@ -269,7 +270,7 @@ class LazyBuffer: # packed assert (C.groups*C.cin) % 4 == 0 - x = x.movement_op(MovementOps.PERMUTE, (0,3,4,1,2)) + x = x.movement_op(MovementOps.PERMUTE, (0,2,3,1)) x = x.movement_op(MovementOps.RESHAPE, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4)) assert C.cout % 4 == 0