Add test for 2D tensor indexing in setitem (#12193)

* Add test for 2D tensor indexing in setitem

* Fix _masked_setitem to handle multi dim indexing correctly

* Fix indent

* Add fuzz test for 3D tensor indexing in setitem

* Skip indexing fuzz test (slow)
This commit is contained in:
Shun Usami
2025-09-16 11:57:25 -07:00
committed by GitHub
parent c7b03457d7
commit 2a72b00679
2 changed files with 28 additions and 1 deletions

View File

@@ -1,4 +1,6 @@
import unittest
import random
from os import getenv
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.helpers import Context
import numpy as np
@@ -176,6 +178,30 @@ class TestSetitem(unittest.TestCase):
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
def test_setitem_2d_tensor_indexing(self):
t = Tensor.zeros(2).contiguous()
index = Tensor([[0, 1], [1,0]])
v = Tensor.arange(2*2).reshape(2, 2).contiguous()
t[index] = v
n = np.zeros((2,))
n[index.numpy()] = v.numpy()
np.testing.assert_allclose(t.numpy(), n)
@unittest.skip("slow")
def test_setitem_tensor_indexing_fuzz(self):
random.seed(getenv("SEED", 42))
for _ in range(getenv("ITERS", 100)):
size = random.randint(5, 10)
d0, d1, d2 = random.randint(1,5), random.randint(1,5), random.randint(1,5)
t = Tensor.zeros(size).contiguous()
n = np.zeros((size,))
index = Tensor.randint((d0, d1, d2), low=0, high=size)
v = Tensor.arange(d0*d1*d2).reshape(d0, d1, d2)
t[index] = v
n[index.numpy()] = v.numpy()
np.testing.assert_allclose(t.numpy(), n, err_msg=f"failed with index={index.numpy().tolist()} and v={v.numpy().tolist()}")
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):
z = Tensor.rand(8, 8)

View File

@@ -98,7 +98,8 @@ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
# reduce such that if mask contains repeated indices the last one remains
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
for dim in reversed(axes):
mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
# remove extra dims from reduce
for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
# select from values for each True element in mask else select from target