diff --git a/tinygrad/llops/ops_torch.py b/tinygrad/llops/ops_torch.py index c5bbc0ca72..1bef520ac1 100644 --- a/tinygrad/llops/ops_torch.py +++ b/tinygrad/llops/ops_torch.py @@ -32,13 +32,9 @@ def convdw(x,grad_output,dw,stride,groups): grad_output = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox).repeat(1, 1, C.cin, 1, 1) grad_output = grad_output.reshape(C.bs * C.groups * C.rcout * C.cin, 1, C.oy, C.ox) x = x.reshape(1, C.bs * C.groups * C.cin, C.iy, C.ix) - #print(input.shape, grad_output.shape) grad_weight = torch.conv2d(x, grad_output, dilation=stride, groups=C.bs*C.groups*C.cin) - grad_weight = grad_weight.reshape(C.bs, grad_weight.shape[1] // C.bs, *grad_weight.shape[2:]).sum(dim=0) - grad_weight = grad_weight.view(C.groups, C.cin, C.rcout, *grad_weight.shape[1:]).transpose(2, 1) - # narrow removes excess for strided - dw[:] = grad_weight.contiguous().view(C.groups*C.rcout, C.cin, *grad_weight.shape[3:]).narrow( - 2, 0, dw.shape[2]).narrow(3, 0, dw.shape[3]) + grad_weight = grad_weight.reshape(C.bs, C.groups, C.cin, C.rcout, *grad_weight.shape[2:]).sum(dim=0).transpose(2, 1) + dw[:] = grad_weight.reshape(C.groups*C.rcout, C.cin, *grad_weight.shape[3:])[:, :, :dw.shape[2], :dw.shape[3]] return dw def processing_op(op,x,w,ret,stride,groups):