transpose dilation was simple

George Hotz
2022-06-15 15:20:51 -07:00
parent 2a14befb74
commit 86f55b078d
2 changed files with 6 additions and 2 deletions


@@ -27,13 +27,16 @@ from tinygrad.ops import ProcessingOps
 def processing_op(op,x,w,ret,C):
   stride, groups, dilation, padding = (C.ys, C.xs), C.groups, (C.dy, C.dx), (C.py, C.px)
   # stride is the same as doing the full conv and slicing with stride at the end
+  # dilation is the same as conving with a weight matrix with 0s added
   if op == ProcessingOps.CONV:
     ret[:] = torch.conv2d(x, w, stride=stride, groups=groups, dilation=dilation, padding=padding)
   elif op == ProcessingOps.CONVT:
     if stride == (1,1) and dilation == (1,1):
-      # strided needs weird "undilation": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
+      # strided needs weird "unstride": https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
       # it's 0 insertion between the inputs
       w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
-      ret[:] = torch.conv2d(x, w, padding=(C.H-C.py-1,C.W-C.px-1), groups=groups)
+      ret[:] = torch.conv2d(x, w, dilation=dilation, padding=((C.H-1)*C.dy-C.py,(C.W-1)*C.dx-C.px), groups=groups)
     else:
       output_padding = [ret.shape[d+2] - ((x.shape[d+2] - padding[d]*2 - 1) * stride[d] + 1 + dilation[d] * (w.shape[d+2] - 1)) for d in range(2)]
       ret[:] = torch.conv_transpose2d(x, w, padding=padding, stride=stride, groups=groups, output_padding=output_padding, dilation=dilation)
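
The comment near the top of this hunk, that dilation is the same as a conv whose kernel has zeros inserted between its taps, is easy to sanity-check. The sketch below is not from the repo; shapes and names are made up, single channel, no groups.

import torch

# a dilated conv equals a plain conv with a zero-stuffed ("dilated") kernel
d, k = 2, 3
x = torch.randn(1, 1, 10, 10)
w = torch.randn(1, 1, k, k)

wz = torch.zeros(1, 1, d*(k-1)+1, d*(k-1)+1)  # 5x5 kernel, zeros between the 3x3 taps
wz[:, :, ::d, ::d] = w

assert torch.allclose(torch.conv2d(x, w, dilation=d), torch.conv2d(x, wz), atol=1e-5)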

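The stride-(1,1) CONVT path above relies on a standard identity: a stride-1 transposed convolution equals a regular convolution with the kernel rotated 180 degrees, input and output channels swapped, the same dilation, and a padding of dilation*(kernel-1) minus the original padding, which is exactly the ((C.H-1)*C.dy-C.py, (C.W-1)*C.dx-C.px) term. A minimal check, not from the repo, with made-up shapes and groups=1:

import torch

bs, cin, cout, k = 2, 3, 4, 3
dy, dx, py, px = 2, 2, 1, 1                     # dilation and padding
x = torch.randn(bs, cin, 8, 8)
w = torch.randn(cin, cout, k, k)                # conv_transpose2d weight layout: (in, out, kH, kW)

ref = torch.conv_transpose2d(x, w, padding=(py, px), dilation=(dy, dx))

wf = w.flip(2, 3).transpose(0, 1)               # rotate the kernel 180 degrees, swap in/out channels
out = torch.conv2d(x, wf, dilation=(dy, dx), padding=((k-1)*dy-py, (k-1)*dx-px))

assert torch.allclose(ref, out, atol=1e-4)

The else branch still goes through conv_transpose2d directly, choosing output_padding as the difference between ret's spatial shape and the size the transposed conv would otherwise produce, so the result fills ret exactly.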

@@ -182,6 +182,7 @@ class Conv2D(Function):
       # this expand is slow
       grad_output_dw = ctx.movement_op(MovementOps.EXPAND, grad_output_dw, (C.bs * C.groups, C.cin, C.rcout, C.oy, C.ox))
       grad_output_dw = ctx.movement_op(MovementOps.RESHAPE, grad_output_dw, (C.bs * C.groups * C.cin * C.rcout, 1, C.oy, C.ox))
+      # padding is the same, stride and dilation are flipped
       Cdw = get_conv_args(xdw.shape, grad_output_dw.shape, padding=(C.py, C.px), stride=(C.dy, C.dx), dilation=(C.ys, C.xs), groups=C.bs*C.groups*C.cin)
       grad_weight = ctx.processing_op(ProcessingOps.CONV, xdw, grad_output_dw, (Cdw.bs, Cdw.cout, Cdw.oy, Cdw.ox), Cdw)
       grad_weight = ctx.movement_op(MovementOps.RESHAPE, grad_weight, (C.bs, C.groups, C.cin, C.rcout, Cdw.oy, Cdw.ox))
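
The "stride and dilation are flipped" comment is the standard weight-gradient identity: for a forward conv with stride s and dilation d, dL/dw is itself a correlation of the input with grad_output in which s acts as the dilation and d acts as the stride, with the same padding. A minimal check, not from the repo, with batch 1, one input and one output channel, and made-up sizes:

import torch

stride, dilation, padding = 2, 3, 1
x = torch.randn(1, 1, 20, 20)
w = torch.randn(1, 1, 3, 3, requires_grad=True)

y = torch.conv2d(x, w, stride=stride, dilation=dilation, padding=padding)
g = torch.ones_like(y)                        # stand-in upstream gradient
y.backward(g)

# forward stride becomes the backward dilation and vice versa
gw = torch.conv2d(x, g, stride=dilation, dilation=stride, padding=padding)
gw = gw[:, :, :w.shape[2], :w.shape[3]]       # crop in case the output overshoots the kernel size

assert torch.allclose(w.grad, gw, atol=1e-4)

That is why Cdw is built with the forward dilation (C.dy, C.dx) passed as stride and the forward stride (C.ys, C.xs) passed as dilation, while the padding (C.py, C.px) is reused unchanged.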