mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix __getitem__
This commit is contained in:
@@ -385,7 +385,7 @@ class Tensor:
|
||||
# compute sum_dim, arange, and idx
|
||||
max_dim = max(i.ndim for i in idx)
|
||||
sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(tdim)]
|
||||
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
|
||||
arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)).contiguous() for n,(sd,d) in enumerate(zip(sum_dim, tdim))] # noqa: E501
|
||||
first_idx = [idx[0].reshape(*[1]*tdim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - tdim[0] - 1))]
|
||||
rest_idx = [i.reshape(*[1]*tdim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - tdim[0] - n)) for n,i in enumerate(idx[1:], 1)]
|
||||
reshaped_idx = first_idx + rest_idx
|
||||
|
||||
Reference in New Issue
Block a user