From 22feb7330eb30fc84a971ddd329bb8993b1d2892 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 13 Dec 2023 14:45:50 -0500 Subject: [PATCH] simplify fancy index with negative Tensor entries (#2749) --- tinygrad/tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ff3b69ddaf..2ed79b8c57 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -385,7 +385,8 @@ class Tensor: for tensor_dim in type_dim[Tensor]: dims_collapsed_, dims_injected = sum(1 for d in dims_collapsed if tensor_dim >= d), sum(1 for d in type_dim[None] if tensor_dim >= d) tdim.append(td := tensor_dim - dims_collapsed_ + dims_injected) - idx.append((t := indices[tensor_dim + dims_injected]).sign().__neg__().relu() * ret.shape[td] + t) # normalize the negative tensor indices + # normalize the negative tensor indices + idx.append(((t := indices[tensor_dim + dims_injected]) < 0).where(ret.shape[td], 0) + t) # compute sum_dim, arange, and idx max_dim = max(i.ndim for i in idx)