mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
cleanups for IMAGE=2 conv
This commit is contained in:
@@ -263,23 +263,17 @@ class LazyBuffer:
|
||||
# hack for non multiples of 4 on C.cin
|
||||
if C.cin % 4 != 0 and not (C.cin == 1 and C.groups%4 == 0):
|
||||
to_add = 4 - (C.cin % 4)
|
||||
ws = [(0, 0) for _ in w.shape]
|
||||
ws[2] = (0, to_add)
|
||||
w = w.movement_op(MovementOps.PAD, ws)
|
||||
w = w.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(w.shape))])
|
||||
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
xs = [(0, 0) for _ in x.shape]
|
||||
xs[2] = (0, to_add)
|
||||
x = x.movement_op(MovementOps.PAD, xs)
|
||||
x = x.movement_op(MovementOps.PAD, [(0, to_add) if i == 2 else (0, 0) for i in range(len(x.shape))])
|
||||
C = C._replace(cin = C.cin + to_add)
|
||||
x = x.movement_op(MovementOps.RESHAPE, (C.bs, C.groups*C.cin, C.iy, C.ix))
|
||||
|
||||
# hack for non multiples of 4 on C.rcout
|
||||
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
|
||||
added_output_channels = 4 - (C.rcout % 4)
|
||||
ws = [(0, 0) for _ in w.shape]
|
||||
ws[1] = (0, added_output_channels)
|
||||
w = w.movement_op(MovementOps.PAD, ws)
|
||||
w = w.movement_op(MovementOps.PAD, [(0, added_output_channels) if i == 1 else (0, 0) for i in range(len(w.shape))])
|
||||
C = C._replace(rcout = C.rcout + added_output_channels, cout = C.groups * (C.rcout + added_output_channels))
|
||||
else:
|
||||
added_output_channels = 0
|
||||
@@ -347,9 +341,7 @@ class LazyBuffer:
|
||||
# undo hack for non multiples of 4 on C.rcout
|
||||
if added_output_channels != 0:
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.groups, C.rcout))
|
||||
xs = [(0, s) for s in ret.shape]
|
||||
xs[4] = (0, ret.shape[4]-added_output_channels)
|
||||
ret = ret.movement_op(MovementOps.SHRINK, xs)
|
||||
ret = ret.movement_op(MovementOps.SHRINK, [(0, s-added_output_channels) if i == 4 else (0, s) for i,s in enumerate(ret.shape)])
|
||||
C = C._replace(rcout = C.rcout - added_output_channels, cout = C.groups * (C.rcout - added_output_channels))
|
||||
|
||||
ret = ret.movement_op(MovementOps.RESHAPE, (C.bs, C.oy, C.ox, C.cout))
|
||||
|
||||
Reference in New Issue
Block a user