diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py index 6ffbb75a1e..917161b33c 100644 --- a/tinygrad/mlops.py +++ b/tinygrad/mlops.py @@ -192,10 +192,9 @@ class Conv2D(Function): def backward(ctx, grad_output): x, w, C = ctx.saved_tensors dx, dw = None, None - if ctx.needs_input_grad[0]: - #dx = ctx.processing_op(ProcessingOps.CONVT, grad_output, w, x.shape, C) if ctx.needs_input_grad[0] else None + if ctx.needs_input_grad[0]: # compute derivative of inputs using ProcessingOps.CONV (this is a transposed conv) xt = grad_output - if C.xs > 1 or C.ys > 1: # unstride. note, this is really memory intensive for big strides. + if C.xs > 1 or C.ys > 1: # unstride. NOTE: this is really memory intensive for big strides. xt = ctx.movement_op(MovementOps.RESHAPE, xt, (grad_output.shape[0], grad_output.shape[1], grad_output.shape[2], 1, grad_output.shape[3], 1)) xt = ctx.movement_op(MovementOps.SLICE, xt, ((0,xt.shape[0]), (0,xt.shape[1]), (0,xt.shape[2]), (0,C.ys), (0,xt.shape[4]), (0,C.xs))) xt = ctx.movement_op(MovementOps.RESHAPE, xt, (xt.shape[0], xt.shape[1], xt.shape[2]*C.ys, xt.shape[4]*C.xs)) @@ -209,16 +208,15 @@ class Conv2D(Function): Cdx = get_conv_args(xt.shape, wt.shape, dilation=(C.dy, C.dx), padding=(px, px_, py, py_), groups=C.groups) dx = ctx._conv(xt, wt, Cdx) - if ctx.needs_input_grad[1]: - # compute derivative of weights using ProcessingOps.CONV + if ctx.needs_input_grad[1]: # compute derivative of weights using ProcessingOps.CONV xdw = ctx.movement_op(MovementOps.RESHAPE, x, (C.bs, C.groups, C.cin, C.iy, C.ix)) xdw = ctx.movement_op(MovementOps.PERMUTE, xdw, (2,1,0,3,4)) xdw = ctx.movement_op(MovementOps.RESHAPE, xdw, (C.cin, C.groups*C.bs, C.iy, C.ix)) grad_output_dw = ctx.movement_op(MovementOps.PERMUTE, grad_output, (1,0,2,3)) grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.cout, C.bs, C.oy, C.ox)) - Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, C.px_, C.py, C.py_), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) + py_ = (w.shape[2] - 1) * C.dy - xdw.shape[2] - C.py + C.ys * (grad_output_dw.shape[2]-1) + 1 + px_ = (w.shape[3] - 1) * C.dx - xdw.shape[3] - C.px + C.xs * (grad_output_dw.shape[3]-1) + 1 + Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.px, px_, C.py, py_), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.groups) grad_weight = ctx._conv(xdw, grad_output_dw, Cdw) - grad_weight = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) - # TODO: remove this slice using asymmetric padding - dw = ctx.movement_op(MovementOps.SLICE, grad_weight, ((0, grad_weight.shape[0]), (0, grad_weight.shape[1]), (0, w.shape[2]), (0, w.shape[3]))) + dw = ctx.movement_op(MovementOps.PERMUTE, grad_weight, (1,0,2,3)) return dx, dw