one more line

George Hotz
2022-07-03 17:28:42 -07:00
parent 02cd8510cb
commit 4d4ea47ca7

@@ -170,6 +170,7 @@ class Conv2D(Function):
    x, w = ctx.saved_tensors
    C = ctx.C # conv args from the context
    dx, dw = None, None
    if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
      xt = grad_output
      if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
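
Note on this hunk: the input gradient of a strided conv is computed as a transposed conv, and "unstride" means zero-stuffing grad_output so each output value lands back on the input position it was read from; that intermediate grows by roughly a factor of sx*sy, which is why the comment flags big strides as memory intensive. Below is a minimal 1D NumPy sketch of the same idea, illustrative only and not tinygrad code (conv1d and conv1d_input_grad are names made up for the example):

import numpy as np

def conv1d(x, w, stride):
    # plain valid correlation: y[i] = sum_k x[i*stride + k] * w[k]
    out = (len(x) - len(w)) // stride + 1
    return np.array([np.dot(x[i*stride:i*stride+len(w)], w) for i in range(out)])

def conv1d_input_grad(dy, w, stride, in_len):
    # 1) "unstride": zero-stuff dy so each output sits on the input position it came from
    dyz = np.zeros((len(dy) - 1) * stride + 1)
    dyz[::stride] = dy
    # 2) full correlation with the flipped kernel (a transposed conv) gives dx
    pad = len(w) - 1
    dyz = np.pad(dyz, (pad, pad))
    dx = np.array([np.dot(dyz[i:i+len(w)], w[::-1]) for i in range(len(dyz) - len(w) + 1)])
    return np.pad(dx, (0, in_len - len(dx)))  # trailing inputs the conv never touched get zero grad

rng = np.random.default_rng(0)
x, w, s = rng.standard_normal(10), rng.standard_normal(3), 2
dy = rng.standard_normal(conv1d(x, w, s).shape)

# reference: push each dy[i]-scaled kernel window back onto the inputs it read
ref = np.zeros_like(x)
for i, g in enumerate(dy):
    ref[i*s:i*s+len(w)] += g * w
assert np.allclose(conv1d_input_grad(dy, w, s, len(x)), ref)
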
@@ -179,6 +180,7 @@ class Conv2D(Function):
      wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
      wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
      py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
      # TODO: move padding backsolver into get_conv_args with an output_shape parameter
      py_ = x.shape[2] - xt.shape[2] + C.py
      px_ = x.shape[3] - xt.shape[3] + C.px
      Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=(px, px_, py, py_), groups=C.groups)
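
Note on the padding backsolve: py/py_ (and px/px_) are picked so the transposed conv's output height and width land exactly back on x's. Assuming the usual stride-1 output-size formula out = in + pad_before + pad_after - dilation*(K-1), the terms cancel to x.shape[2]; a throwaway check, not tinygrad code (conv_out is made up for the example):

def conv_out(in_len, k, pad_before, pad_after, dil):
    # standard output-size formula at stride 1 (assumption for this sketch)
    return in_len + pad_before + pad_after - dil * (k - 1)

for x_h, xt_h, H, d, p in [(32, 30, 3, 1, 0), (28, 30, 5, 2, 5), (16, 12, 3, 2, 0)]:
    py  = (H - 1) * d - p     # py as computed in the hunk above
    py_ = x_h - xt_h + p      # the backsolved other-side padding
    # everything cancels: xt_h + py + py_ - d*(H-1) == x_h
    assert conv_out(xt_h, H, py, py_, d) == x_h

Once get_conv_args grows an output_shape parameter, as the TODO suggests, this backsolve can move out of backward.
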
@@ -191,6 +193,6 @@ class Conv2D(Function):
      py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1
      px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1
      Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
      grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
      dw = grad_weight.movement_op(MovementOps.PERMUTE, (1,0,2,3))
      dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
    return dx, dw
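
Note on the weight gradient: in the simple case (stride 1, dilation 1, groups 1, no padding, batch 1) the gradient of the weights is itself a valid correlation of the input with the output gradient, just with the two leading axes swapped, which is what the final PERMUTE (1,0,2,3) puts back. In the general case the forward stride and dilation swap roles, hence Cdw being built with stride=(C.dy, C.dx) and dilation=(C.sy, C.sx). A NumPy sketch of the simple case, illustrative only and not tinygrad code (corr2d and the shapes are made up):

import numpy as np

def corr2d(x, w):
    # x: (Cin, H, W), w: (Cout, Cin, KH, KW) -> y: (Cout, OH, OW); stride 1, no padding
    Cout, Cin, KH, KW = w.shape
    OH, OW = x.shape[1] - KH + 1, x.shape[2] - KW + 1
    y = np.zeros((Cout, OH, OW))
    for o in range(Cout):
        for i in range(OH):
            for j in range(OW):
                y[o, i, j] = np.sum(x[:, i:i+KH, j:j+KW] * w[o])
    return y

rng = np.random.default_rng(0)
x  = rng.standard_normal((2, 6, 6))     # Cin=2
w  = rng.standard_normal((3, 2, 3, 3))  # Cout=3, 3x3 kernel
dy = rng.standard_normal(corr2d(x, w).shape)

# direct weight gradient: dw[o, c, p, q] = sum_{i,j} dy[o, i, j] * x[c, p+i, q+j]
dw_ref = np.zeros_like(w)
for o in range(3):
    for c in range(2):
        for p in range(3):
            for q in range(3):
                dw_ref[o, c, p, q] = np.sum(dy[o] * x[c, p:p+dy.shape[1], q:q+dy.shape[2]])

# the same thing as a conv: correlate each input channel with each grad_output channel,
# which lands in (Cin, Cout, KH, KW) layout -- the transpose mirrors the PERMUTE (1,0,2,3)
dw_conv = np.zeros((2, 3, 3, 3))
for c in range(2):
    for o in range(3):
        dw_conv[c, o] = corr2d(x[c][None], dy[o][None, None])[0]
assert np.allclose(dw_ref, dw_conv.transpose(1, 0, 2, 3))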