diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 357bd3ebea..13d955a85e 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -385,7 +385,7 @@ class Tensor: # compute sum_dim, arange, and idx max_dim = max(i.ndim for i in idx) sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)] - arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501 + arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)).contiguous() for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501 first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))] rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)] reshaped_idx = first_idx + rest_idx