mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
fix strided conv
This commit is contained in:
@@ -30,14 +30,17 @@ def conv(x,w,ret,stride,groups):
|
||||
|
||||
def convdw(input,grad_output,dw,stride,groups):
|
||||
# NOTE: torch.nn.grad.conv2d_weight is wrong for groups in pytorch, wonder who it affects
|
||||
# https://github.com/pytorch/pytorch/issues/51430
|
||||
C = get_conv_args(input.shape, dw.shape, stride, groups)
|
||||
grad_output = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox).repeat(1, 1, C.cin, 1, 1)
|
||||
grad_output = grad_output.reshape(C.bs * C.groups * C.rcout * C.cin, 1, C.oy, C.ox)
|
||||
input = input.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix)
|
||||
grad_weight = torch.nn.functional.conv2d(input, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin)
|
||||
grad_weight = grad_weight.reshape(C.bs,-1).sum(dim=0)
|
||||
grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, C.H, C.W).transpose(2, 1)
|
||||
dw[:] = grad_weight.contiguous().view(C.groups*C.rcout, C.cin, C.H, C.W)
|
||||
grad_weight = grad_weight.reshape(C.bs, grad_weight.shape[1] // C.bs, *grad_weight.shape[2:]).sum(dim=0)
|
||||
grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, *grad_weight.shape[1:]).transpose(2, 1)
|
||||
# narrow removes excess for strided
|
||||
dw[:] = grad_weight.contiguous().view(C.groups*C.rcout, C.cin, *grad_weight.shape[3:]).narrow(
|
||||
2, 0, dw.shape[2]).narrow(3, 0, dw.shape[3])
|
||||
return dw
|
||||
|
||||
def convdx(w,grad_output,dx,stride,groups):
|
||||
|
||||
Reference in New Issue
Block a user