mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 06:48:22 -05:00
simplify fancy index with negative Tensor entries (#2749)
This commit is contained in:
@@ -385,7 +385,8 @@ class Tensor:
|
||||
for tensor_dim in type_dim[Tensor]:
|
||||
dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d)
|
||||
tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected)
|
||||
idx.append((t := indices[tensor_dim + dims_injected]).sign().__neg__().relu() * ret.shape[td] + t) # normalize the negative tensor indices
|
||||
# normalize the negative tensor indices
|
||||
idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t)
|
||||
|
||||
# compute sum_dim, arange, and idx
|
||||
max_dim = max(i.ndim for i in idx)
|
||||
|
||||
Reference in New Issue
Block a user