From 58fa82eef502d6d3ec4be3d6cf76aa0580ae1326 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 17 Feb 2026 08:36:09 -0500 Subject: [PATCH] stronger test_assign_add (#14826) also test self add 10 and 100 times --- test/unit/test_assign.py | 54 +++++++++++++++------------------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py index 1e3ab08280..36c061d47a 100644 --- a/test/unit/test_assign.py +++ b/test/unit/test_assign.py @@ -36,42 +36,28 @@ class TestAssign(unittest.TestCase): np.testing.assert_allclose(b.numpy(), 0) def test_assign_add(self): - x = Tensor([0]).realize() - buf = x.uop.base.realized - x += 1 - x.realize() - assert x.item() == 1 - assert x.uop.base.realized is buf - - def test_assign_add_twice(self): - # NOTE: this has two kernels - x = Tensor([0]).realize() - buf = x.uop.base.realized - x += 1 - x += 1 - x.realize() - assert x.item() == 2 - # TODO: both assigns should write to the original buffer, not create a new one - with self.assertRaises(AssertionError): - assert x.uop.base.realized is buf + for T in (1, 2, 10, 100): + x = Tensor([0]).realize() + buf = x.uop.base.realized + for _ in range(T): + x += 1 + x.realize() + assert x.item() == T + if T == 1: + assert x.uop.base.realized is buf + else: + # TODO: this is wrong, it should always return the same buffer + assert x.uop.base.realized is not buf def test_assign_slice_add(self): - x = Tensor([0, 0]).realize() - buf = x.uop.base.realized - x[0] += 1 - x.realize() - assert x.tolist() == [1, 0] - assert x.uop.base.realized is buf - - def test_assign_slice_add_twice(self): - # NOTE: this has two kernels - x = Tensor([0, 0]).realize() - buf = x.uop.base.realized - x[0] += 1 - x[0] += 1 - x.realize() - assert x.tolist() == [2, 0] - assert x.uop.base.realized is buf + for T in (1, 2, 10, 100): + x = Tensor([0, 0]).realize() + buf = x.uop.base.realized + for _ in range(T): + x[0] += 1 + x.realize() + assert x.tolist() == [T, 0] + assert x.uop.base.realized is buf def test_assign_add_double(self): def f(x):