simplify fancy index with negative Tensor entries (#2749)

This commit is contained in:
chenyu
2023-12-13 14:45:50 -05:00
committed by GitHub
parent b229879613
commit 22feb7330e

View File

@@ -385,7 +385,8 @@ class Tensor:
for tensor_dim in type_dim[Tensor]:
dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
idx.append((t := indices[tensor_dim + dims_injected]).sign().__neg__().relu() * ret.shape[td] + t) # normalize the negative tensor indices
# normalize the negative tensor indices
idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
# compute sum_dim, arange, and idx
max_dim = max(i.ndim for i in idx)