diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py index 8494416ce5..a3100ee6dd 100644 --- a/tinygrad/llops/ops_cpu.py +++ b/tinygrad/llops/ops_cpu.py @@ -100,7 +100,7 @@ def convdw(x,grad_output,dw,stride,groups): return dw def convdx(w,grad_output,dx,stride,groups): - C = get_conv_args(x.shape, dw.shape, stride, groups) + C = get_conv_args(dx.shape, w.shape, stride, groups) ggg = grad_output.reshape(C.bs, C.groups, C.rcout, C.oy, C.ox) tw = w.reshape(C.groups, C.rcout, C.cin, C.H, C.W) gdx = dx.reshape((C.bs, C.groups, C.cin, C.iy, C.ix))