use conv2d for transpose when able

This commit is contained in:
George Hotz
2022-06-11 09:32:13 -07:00
parent 861323c121
commit 1d29780b75

View File

@@ -41,12 +41,14 @@ def processing_op(op,x,w,ret,stride,groups):
if op == ProcessingOps.CONV:
ret[:] = torch.conv2d(x, w, stride=stride, groups=groups)
elif op == ProcessingOps.CONVT:
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + (w.shape[d+2] - 1)) for d in range(2)]
ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding)
# wrong, strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
#C = get_conv_args(ret.shape, w.shape, stride, groups)
#w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
#ret[:] = torch.conv2d(x, w, padding=(C.H-1,C.W-1), groups=groups)
if stride == 1 or stride == (1,1):
# strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
C = get_conv_args(ret.shape, w.shape, stride, groups)
w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
ret[:] = torch.conv2d(x, w, padding=(C.H-1,C.W-1), groups=groups)
else:
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + (w.shape[d+2] - 1)) for d in range(2)]
ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding)
elif op == ProcessingOps.CONVDW:
convdw(x,w,ret,stride,groups)
return ret