setitem in-place operator tests (#4577)

* tests and error

* rename to in-place

* add a note

* more comments

* more comments

* disable folded advanced setitem tests for now
This commit is contained in:
geohotstan
2024-05-14 13:28:02 +08:00
committed by GitHub
parent 0fa57b8ce9
commit 089eeec271
3 changed files with 43 additions and 0 deletions

View File

@@ -42,6 +42,41 @@ class TestSetitem(unittest.TestCase):
assert not t.lazydata.st.contiguous
with self.assertRaises(AssertionError): t[1] = 5
def test_setitem_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 5]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 2]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] *= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 6]])
# NOTE: have to manually cast setitem target to least_upper_float for div
t = Tensor.arange(4, dtype=dtypes.float).reshape(2, 2).contiguous()
t[1] /= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [1, 1.5]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] **= 2
np.testing.assert_allclose(t.numpy(), [[0, 1], [4, 9]])
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] ^= 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [7, 6]])
@unittest.expectedFailure
def test_setitem_consecutive_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
t = t.contiguous()
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
# TODO: implement fancy setitem
@unittest.expectedFailure
def test_fancy_setitem(self):