From 2a72b00679784ea0d8052cdf083c09f809b03a55 Mon Sep 17 00:00:00 2001 From: Shun Usami Date: Tue, 16 Sep 2025 11:57:25 -0700 Subject: [PATCH] Add test for 2D tensor indexing in setitem (#12193) * Add test for 2D tensor indexing in setitem * Fix _masked_setitem to handle multi dim indexing correctly * Fix indent * Add fuzz test for 3D tensor indexing in setitem * Skip indexing fuzz test (slow) --- test/test_setitem.py | 26 ++++++++++++++++++++++++++ tinygrad/tensor.py | 3 ++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/test_setitem.py b/test/test_setitem.py index 54ae9007af..2005b7c801 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -1,4 +1,6 @@ import unittest +import random +from os import getenv from tinygrad import Tensor, TinyJit, Variable, dtypes from tinygrad.helpers import Context import numpy as np @@ -176,6 +178,30 @@ class TestSetitem(unittest.TestCase): n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy() np.testing.assert_allclose(t.numpy(), n) + def test_setitem_2d_tensor_indexing(self): + t = Tensor.zeros(2).contiguous() + index = Tensor([[0, 1], [1,0]]) + v = Tensor.arange(2*2).reshape(2, 2).contiguous() + t[index] = v + n = np.zeros((2,)) + n[index.numpy()] = v.numpy() + np.testing.assert_allclose(t.numpy(), n) + + @unittest.skip("slow") + def test_setitem_tensor_indexing_fuzz(self): + random.seed(getenv("SEED", 42)) + for _ in range(getenv("ITERS", 100)): + size = random.randint(5, 10) + d0, d1, d2 = random.randint(1,5), random.randint(1,5), random.randint(1,5) + t = Tensor.zeros(size).contiguous() + n = np.zeros((size,)) + index = Tensor.randint((d0, d1, d2), low=0, high=size) + v = Tensor.arange(d0*d1*d2).reshape(d0, d1, d2) + t[index] = v + n[index.numpy()] = v.numpy() + np.testing.assert_allclose(t.numpy(), n, err_msg=f"failed with index={index.numpy().tolist()} and v={v.numpy().tolist()}") + + class TestWithGrad(unittest.TestCase): def test_no_requires_grad_works(self): z = Tensor.rand(8, 8) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e2978a64e9..20cb713fdb 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -98,7 +98,8 @@ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]: def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor: # reduce such that if mask contains repeated indices the last one remains - for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) + for dim in reversed(axes): + mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim))) # remove extra dims from reduce for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim) # select from values for each True element in mask else select from target