mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
use conv2d for transpose when able
This commit is contained in:
@@ -41,12 +41,14 @@ def processing_op(op,x,w,ret,stride,groups):
|
||||
if op == ProcessingOps.CONV:
|
||||
ret[:] = torch.conv2d(x, w, stride=stride, groups=groups)
|
||||
elif op == ProcessingOps.CONVT:
|
||||
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + (w.shape[d+2] - 1)) for d in range(2)]
|
||||
ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding)
|
||||
# wrong, strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||
#C = get_conv_args(ret.shape, w.shape, stride, groups)
|
||||
#w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
|
||||
#ret[:] = torch.conv2d(x, w, padding=(C.H-1,C.W-1), groups=groups)
|
||||
if stride == 1 or stride == (1,1):
|
||||
# strided needs weird padding: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
|
||||
C = get_conv_args(ret.shape, w.shape, stride, groups)
|
||||
w = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W).flip(3, 4).transpose(2, 1).reshape(C.groups*C.cin, C.rcout, C.H, C.W)
|
||||
ret[:] = torch.conv2d(x, w, padding=(C.H-1,C.W-1), groups=groups)
|
||||
else:
|
||||
output_padding = [ret.shape[d+2] - ((x.shape[d+2] - 1) * stride[d] + 1 + (w.shape[d+2] - 1)) for d in range(2)]
|
||||
ret[:] = torch.conv_transpose2d(x, w, stride=stride, groups=groups, output_padding=output_padding)
|
||||
elif op == ProcessingOps.CONVDW:
|
||||
convdw(x,w,ret,stride,groups)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user