mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
this fixes 2 of the conv recomputes...but it's ugh
This commit is contained in:
@@ -4,7 +4,6 @@ from tinygrad.ops import MovementOps
|
||||
# dweight format is oc//4 x ch, cw x 4(oc)
|
||||
# weight format is oc//4 x ch, ic//4, cw, 4(oc) x 4(ic)
|
||||
def preprocessing_op(ctx,x,w,C):
|
||||
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
w = ctx.movement_op(MovementOps.RESHAPE, w, (C.groups, C.rcout, C.cin, C.H, C.W))
|
||||
#print(x.shape, w.shape)
|
||||
|
||||
@@ -12,7 +11,7 @@ def preprocessing_op(ctx,x,w,C):
|
||||
# explicitly add y-padding for batched inputs
|
||||
# N C H W
|
||||
xs = [(0, s) for s in x.shape]
|
||||
xs[3] = (-C.py, x.shape[3]+C.py)
|
||||
xs[2] = (-C.py, x.shape[2]+C.py)
|
||||
x = ctx.movement_op(MovementOps.SLICE, x, xs)
|
||||
C = C._replace(iy=C.iy + C.py*2, py=0)
|
||||
|
||||
@@ -23,10 +22,12 @@ def preprocessing_op(ctx,x,w,C):
|
||||
ws[2] = (0, w.shape[2]+to_add)
|
||||
w = ctx.movement_op(MovementOps.SLICE, w, ws)
|
||||
|
||||
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix))
|
||||
xs = [(0, s) for s in x.shape]
|
||||
xs[2] = (0, x.shape[2]+to_add)
|
||||
x = ctx.movement_op(MovementOps.SLICE, x, xs)
|
||||
C = C._replace(cin = C.cin + to_add)
|
||||
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups*C.cin, C.iy, C.ix))
|
||||
|
||||
# hack for non multiples of 4 on C.rcout
|
||||
if C.rcout % 4 != 0 and not (C.rcout == 1 and C.groups%4 == 0):
|
||||
@@ -39,7 +40,7 @@ def preprocessing_op(ctx,x,w,C):
|
||||
# packed
|
||||
assert (C.groups*C.cin) % 4 == 0
|
||||
#print(x.shape)
|
||||
x = ctx.movement_op(MovementOps.PERMUTE, x, (0,3,4,1,2))
|
||||
x = ctx.movement_op(MovementOps.PERMUTE, x, (0,2,3,1))
|
||||
x = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs*C.iy, C.ix*C.groups*C.cin//4, 4))
|
||||
|
||||
assert C.cout % 4 == 0
|
||||
|
||||
Reference in New Issue
Block a user