mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Fix advanced tensor indexing setitem (#12128)
* Add failure test case for advanced tensor indexing setitem * Fix advanced tensor indexing setitem when permuted * Reduce line count * Revert unnecessary change * Combine two lines into one
This commit is contained in:
@@ -165,6 +165,17 @@ class TestSetitem(unittest.TestCase):
|
||||
t[idx] = val
|
||||
self.assertEqual(t.tolist(), [val]*idx_size+[idx_size])
|
||||
|
||||
def test_setitem_advanced_indexing(self):
|
||||
# Example from https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
|
||||
t = Tensor.zeros(10,20,30,40,50).contiguous()
|
||||
ind_1 = Tensor([5,3,7,8])
|
||||
ind_2 = Tensor([[[0],[1],[2]],[[3],[4],[5]]])
|
||||
v = Tensor.arange(2*3*4*10*30*50).reshape(2,3,4,10,30,50)
|
||||
t[:, ind_1, :, ind_2, :] = v
|
||||
n = np.zeros((10,20,30,40,50))
|
||||
n[:, ind_1.numpy(), :, ind_2.numpy(), :] = v.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), n)
|
||||
|
||||
class TestWithGrad(unittest.TestCase):
|
||||
def test_no_requires_grad_works(self):
|
||||
z = Tensor.rand(8, 8)
|
||||
|
||||
Reference in New Issue
Block a user