From 4d4ea47ca71533192a174e00c863b4ea58be35cf Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 3 Jul 2022 17:28:42 -0700
Subject: [PATCH] one more line

---
 tinygrad/mlops.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 319ea7aa60..c36956a485 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -170,6 +170,7 @@ class Conv2D(Function):
     x, w = ctx.saved_tensors
     C = ctx.C  # conv args from the context
     dx, dw = None, None
+
     if ctx.needs_input_grad[0]:  # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv)
       xt = grad_output
       if C.sx > 1 or C.sy > 1:  # unstride. NOTE: this is really memory intensive for big strides.
@@ -179,6 +180,7 @@ class Conv2D(Function):
       wt = w.movement_op(MovementOps.RESHAPE, (C.groups, C.rcout, C.cin, C.H, C.W)).movement_op(MovementOps.PERMUTE, (0, 2, 1, 3, 4))
       wt = wt.movement_op(MovementOps.RESHAPE, (C.groups*C.cin, C.rcout, C.H, C.W)).movement_op(MovementOps.FLIP, (2, 3))
       py, px = (C.H-1)*C.dy - C.py, (C.W-1)*C.dx - C.px
+      # TODO: move padding backsolver into get_conv_args with an output_shape parameter
       py_ = x.shape[2] - xt.shape[2] + C.py
       px_ = x.shape[3] - xt.shape[3] + C.px
       Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=(px, px_, py, py_), groups=C.groups)
@@ -191,6 +193,6 @@ class Conv2D(Function):
       py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.sy * (grad_output_dw.shape[2]-1) + 1
       px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.sx * (grad_output_dw.shape[3]-1) + 1
       Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.sy, C.sx), groups=C.groups)
-      grad_weight = ctx._conv(xdw, grad_output_dw, Cdw)
-      dw = grad_weight.movement_op(MovementOps.PERMUTE, (1,0,2,3))
+      dw = ctx._conv(xdw, grad_output_dw, Cdw).movement_op(MovementOps.PERMUTE, (1,0,2,3))
+
     return dx, dw
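
For reference, a minimal standalone sketch of the shape arithmetic behind the padding back-solve that the new TODO points at. It is not part of the commit: the names (conv_out, x_h, xt_h, base_py) and the example sizes are illustrative, and it assumes the standard convolution output-size formula with floor division.

# Sketch only: check that the asymmetric padding computed for the input-gradient
# (transposed) conv reproduces the original input's spatial size.
# Assumed formula: out = (in + pad_lo + pad_hi - dil*(k-1) - 1)//stride + 1.

def conv_out(size, k, pad_lo, pad_hi, dil=1, stride=1):
  return (size + pad_lo + pad_hi - dil * (k - 1) - 1) // stride + 1

# illustrative stride-1 case (no unstride step needed):
# x_h = conv input height, xt_h = grad_output height
x_h, H, dy, base_py = 32, 5, 1, 0
xt_h = conv_out(x_h, H, base_py, base_py, dil=dy)  # forward conv output height

py  = (H - 1) * dy - base_py   # mirrors py  = (C.H-1)*C.dy - C.py
py_ = x_h - xt_h + base_py     # mirrors py_ = x.shape[2] - xt.shape[2] + C.py

# the transposed conv over grad_output must give back x's height
assert conv_out(xt_h, H, py, py_, dil=dy) == x_h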