this fixes 2 of the conv recomputes...but it's ugh

This commit is contained in:
George Hotz
2022-06-22 08:18:12 -07:00
parent b2d5df6049
commit 73415e20ab

View File

@@ -4,7 +4,6 @@ from tinygrad.ops import MovementOps
# dweight format is oc//4 x ch, cw x 4(oc)
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
def preprocessing_op(ctx,x,w,C):
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
w = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
#print(x.shape, w.shape)
@@ -12,7 +11,7 @@ def preprocessing_op(ctx,x,w,C):
# explictly add y-padding for batched inputs
# N C H W
xs = [(0, s) for s in x.shape]
xs[3] = (-C.py, x.shape[3]+C.py)
xs[2] = (-C.py, x.shape[2]+C.py)
x = ctx.movement_op(MovementOps.SLICE, x, xs)
C = C._replace(iy=C.iy + C.py*2, py=0)
@@ -23,10 +22,12 @@ def preprocessing_op(ctx,x,w,C):
ws[2] = (0, w.shape[2]+to_add)
w = ctx.movement_op(MovementOps.SLICE, w, ws)
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
xs = [(0, s) for s in x.shape]
xs[2] = (0, x.shape[2]+to_add)
x = ctx.movement_op(MovementOps.SLICE, x, xs)
C = C._replace(cin = C.cin + to_add)
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups*C.cin, C.iy, C.ix))
# hack for non multiples of 4 on C.rcout
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
@@ -39,7 +40,7 @@ def preprocessing_op(ctx,x,w,C):
# packed
assert (C.groups*C.cin) % 4 == 0
#print(x.shape)
x = ctx.movement_op(MovementOps.PERMUTE, x, (0,3,4,1,2))
x = ctx.movement_op(MovementOps.PERMUTE, x, (0,2,3,1))
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))
assert C.cout % 4 == 0