mirror of https://github.com/tinygrad/tinygrad.git
one more line
@@ -170,6 +170,7 @@ class Conv2D(Function):
     x, w = ctx.saved_tensors
     C = ctx.C # conv args from the context
     dx, dw = None, None
+
     if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
       if C.sx > 1 or C.sy > 1: # unstride. NOTE: this is really memory intensive for big strides.
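
A note on the "unstride" comment above: a common way to remove the stride from the backward conv is to spread grad_output out with zeros so that a stride-1 conv can be used, which is why it gets memory intensive for large strides. A minimal 1-D numpy sketch of that idea (illustrative only, not the tinygrad implementation):

import numpy as np

go = np.array([1., 2., 3., 4.], dtype=np.float32)   # pretend 1-D grad_output; forward stride was sy=2
sy = 2
xt = np.zeros(sy * (len(go) - 1) + 1, dtype=np.float32)
xt[::sy] = go   # insert sy-1 zeros between neighbouring grad_output values
print(xt)       # [1. 0. 2. 0. 3. 0. 4.]
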
@@ -179,6 +180,7 @@ class Conv2D(Function):
       wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
       wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
       py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
+      # TODO: move padding backsolver into get_conv_args with an output_shape parameter
       py_ = x.shape[2] - xt.shape[2] + C.py
       px_ = x.shape[3] - xt.shape[3] + C.px
       Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=(px, px_, py, py_), groups=C.groups)
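
For reference, the two wt lines above amount to a per-group swap of the input and output channel axes followed by a spatial flip of the kernel, i.e. they build the kernel of the transposed conv. A shape-only numpy sketch of the same reshuffle (example sizes; not tinygrad code):

import numpy as np

groups, rcout, cin, H, W = 2, 4, 3, 3, 3   # example conv args
w = np.random.randn(groups * rcout, cin, H, W)
wt = w.reshape(groups, rcout, cin, H, W).transpose(0, 2, 1, 3, 4)   # swap rcout <-> cin within each group
wt = wt.reshape(groups * cin, rcout, H, W)[:, :, ::-1, ::-1]        # flip the kernel along H and W
assert wt.shape == (groups * cin, rcout, H, W)
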
@@ -191,6 +193,6 @@ class Conv2D(Function):
       py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1
       px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1
       Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
-      grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
-      dw = grad_weight.movement_op(MovementOps.PERMUTE, (1,0,2,3))
+      dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))

     return dx, dw
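
The padding arithmetic in both branches (and the TODO about a padding backsolver) follows from the usual convolution output-size formula: the missing paddings are solved so that the dx conv reproduces x's spatial shape and the dw conv produces the kernel's spatial shape. A minimal sketch that checks those relations numerically; conv_out and the sample sizes below are illustrative assumptions, not tinygrad API:

def conv_out(n, k, stride=1, dil=1, pad_lo=0, pad_hi=0):
  # spatial output size of a convolution along one axis
  return (n + pad_lo + pad_hi - dil * (k - 1) - 1) // stride + 1

x_h, H, dy, py, sy = 32, 3, 1, 0, 1       # input size, kernel size, dilation, padding, stride
go_h = conv_out(x_h, H, sy, dy, py, py)   # forward output size

# dx branch (the stride has already been removed by unstriding, so the conv runs at stride 1)
py_top = (H - 1) * dy - py                # py  = (C.H-1)*C.dy - C.py
py_bot = x_h - go_h + py                  # py_ = x.shape[2] - xt.shape[2] + C.py
assert conv_out(go_h, H, 1, dy, py_top, py_bot) == x_h

# dw branch: grad_output acts as a dilated kernel slid over x
py_bot_w = (H - 1) * dy - x_h - py + sy * (go_h - 1) + 1   # matches the py_ formula above
assert conv_out(x_h, go_h, stride=dy, dil=sy, pad_lo=py, pad_hi=py_bot_w) == H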