fix __getitem__

This commit is contained in:
qazal
2023-12-23 13:50:29 +02:00
parent b66a06ba67
commit 62ad719bfa

View File

@@ -385,7 +385,7 @@ class Tensor:
# compute sum_dim, arange, and idx
max_dim = max(i.ndim for i in idx)
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)).contiguous() for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
reshaped_idx = first_idx + rest_idx