diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py
index 807b3d1a8b..3d5da85e76 100644
--- a/tinygrad/llops/ops_torch.py
+++ b/tinygrad/llops/ops_torch.py
@@ -27,13 +27,16 @@ from tinygrad.ops import ProcessingOps
 def processing_op(op,x,w,ret,C):
   stride, groups, dilation, padding = (C.ys, C.xs), C.groups, (C.dy, C.dx), (C.py, C.px)
+  # stride is the same as doing the full conv and slicing with stride at the end
+  # dilation is the same as conving with a weight matrix with 0s added
   if op == ProcessingOps.CONV:
     ret[:] = torch.conv2d(x, w, stride=stride, groups=groups, dilation=dilation, padding=padding)
   elif op == ProcessingOps.CONVT:
     if stride == (1,1) and dilation == (1,1):
-      # strided needs weird "undilation": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+      # strided needs weird "unstride": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+      # it's 0 insertion between the inputs
       w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
-      ret[:] = torch.conv2d(x, w, padding=(C.H-C.py-1,C.W-C.px-1), groups=groups)
+      ret[:] = torch.conv2d(x, w, dilation=dilation, padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=groups)
     else:
       output_padding = [ret.shape[d+2] - ((x.shape[d+2] - padding[d]*2 - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
       ret[:] = torch.conv_transpose2d(x, w, padding=padding, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
 
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 41936c12db..6da4b115a7 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -182,6 +182,7 @@ class Conv2D(Function):
     # this expand is slow
     grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox))
     grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox))
+    # padding is the same, stride and dilation are flipped
     Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
     grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
     grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox))