mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Add test for 2D tensor indexing in setitem (#12193)
* Add test for 2D tensor indexing in setitem * Fix _masked_setitem to handle multi dim indexing correctly * Fix indent * Add fuzz test for 3D tensor indexing in setitem * Skip indexing fuzz test (slow)
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import unittest
|
||||
import random
|
||||
from os import getenv
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.helpers import Context
|
||||
import numpy as np
|
||||
@@ -176,6 +178,30 @@ class TestSetitem(unittest.TestCase):
|
||||
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
def test_setitem_2d_tensor_indexing(self):
|
||||
t = Tensor.zeros(2).contiguous()
|
||||
index = Tensor([[0, 1], [1,0]])
|
||||
v = Tensor.arange(2*2).reshape(2, 2).contiguous()
|
||||
t[index] = v
|
||||
n = np.zeros((2,))
|
||||
n[index.numpy()] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
@unittest.skip("slow")
|
||||
def test_setitem_tensor_indexing_fuzz(self):
|
||||
random.seed(getenv("SEED", 42))
|
||||
for _ in range(getenv("ITERS", 100)):
|
||||
size = random.randint(5, 10)
|
||||
d0, d1, d2 = random.randint(1,5), random.randint(1,5), random.randint(1,5)
|
||||
t = Tensor.zeros(size).contiguous()
|
||||
n = np.zeros((size,))
|
||||
index = Tensor.randint((d0, d1, d2), low=0, high=size)
|
||||
v = Tensor.arange(d0*d1*d2).reshape(d0, d1, d2)
|
||||
t[index] = v
|
||||
n[index.numpy()] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n, err_msg=f"failed with index={index.numpy().tolist()} and v={v.numpy().tolist()}")
|
||||
|
||||
|
||||
class TestWithGrad(unittest.TestCase):
|
||||
def test_no_requires_grad_works(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
|
||||
@@ -98,7 +98,8 @@ def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
|
||||
|
||||
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
|
||||
# reduce such that if mask contains repeated indices the last one remains
|
||||
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
|
||||
for dim in reversed(axes):
|
||||
mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
|
||||
# remove extra dims from reduce
|
||||
for dim in reversed(axes): mask, values = mask.squeeze(dim), values.squeeze(dim)
|
||||
# select from values for each True element in mask else select from target
|
||||
|
||||
Reference in New Issue
Block a user