fix advanced setitem with 1 in shape (#7797)

* fix advanced setitem with 1 in shape

* linter
This commit is contained in:
ttomsa
2024-11-20 01:04:59 +00:00
committed by GitHub
parent d800a79112
commit 9adeb1041c
2 changed files with 9 additions and 2 deletions

View File

@@ -94,6 +94,11 @@ class TestSetitem(unittest.TestCase):
t[[1,1]] = Tensor([1,0])
np.testing.assert_allclose(t.numpy(), [1,0,3,4])
def test_setitem_with_1_in_shape(self):
t = Tensor([[1],[2],[3]])
t[[0,0]] = Tensor([[1],[2]])
np.testing.assert_allclose(t.numpy(), [[2],[2],[3]])
def test_fancy_setitem(self):
t = Tensor.zeros(6,6).contiguous()
t[[1,2], [3,2]] = 3