diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1d48d1a787..9413e80b77 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -464,8 +464,8 @@ class Tensor: #x = x.reshape(bs, groups, cin, rcout, oy, ox, H, W).permute(0,1,3,4,5,2,6,7) # conv! broadcasted to (bs, groups, rcout, *oyx, cin, *HW) - ret = (x * weight.reshape(1, groups, rcout, *[1 for _ in range(len(oyx))], cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) - return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))])) + ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) + return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW))) def dot(self, w:Tensor) -> Tensor: n1, n2 = len(self.shape), len(w.shape)