assert error detail in test_assign (#4567)

* use regex assert

* that shouldnt raise
This commit is contained in:
qazal
2024-05-13 14:56:05 +08:00
committed by GitHub
parent 25ec40ca93
commit b0fa97e176

View File

@@ -29,7 +29,6 @@ class TestAssign(unittest.TestCase):
def test_assign_zeros(self):
a = Tensor.zeros(10,10).contiguous()
b = Tensor.zeros(10,10).contiguous()
#with self.assertRaises(RuntimeError):
a.assign(Tensor.ones(10,10))
a.realize()
np.testing.assert_allclose(b.numpy(), 0)
@@ -117,7 +116,7 @@ class TestAssign(unittest.TestCase):
def test_assign_diamond_cycle(self):
# NOTE: should *not* raise AssertionError from numpy
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "cycle"):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.).contiguous())
@@ -125,7 +124,7 @@ class TestAssign(unittest.TestCase):
np.testing.assert_allclose(new.numpy(), 4)
def test_assign_diamond_contiguous_cycle(self):
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "cycle"):
a = Tensor.ones(4).contiguous().realize()
times_a = a*3
a.assign(Tensor.full((4,), 2.))
@@ -202,7 +201,7 @@ class TestAssign(unittest.TestCase):
def test_crossunder_assign(self):
# NOTE: should *not* raise AssertionError from numpy
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "cycle"):
a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
c = a+9
@@ -273,7 +272,7 @@ class TestAssign(unittest.TestCase):
#GlobalCounters.cache = []
ba1 = a.lazydata.base.realized # noqa: F841
bb1 = b.lazydata.base.realized # noqa: F841
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "contiguous"):
a.assign(a.permute(1,0) + b) # this should not work!
a.realize()
ba2 = a.lazydata.base.realized # noqa: F841
@@ -305,7 +304,7 @@ class TestAssign(unittest.TestCase):
a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
# TODO: scheduler limitation, should NOT raise AssertionError from numpy.
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "contiguous"):
a = a.permute(1, 0)
new_val = a + b
a.assign(new_val)
@@ -314,7 +313,7 @@ class TestAssign(unittest.TestCase):
def test_permuted_reduceop_child_dual_use(self):
a = Tensor.randn(32, 32, 32).realize()
b = Tensor.full((32, 32), 1.).contiguous().realize()
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "contiguous"):
r = a.sum(axis=1)
b.assign(r + b.permute(1, 0))
b.realize()
@@ -325,7 +324,7 @@ class TestAssign(unittest.TestCase):
c = Tensor.full((32, 32), 2.).contiguous().realize()
# TODO: this is failing in cycle error, it should fail earlier.
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(RuntimeError, "cycle"):
r = a.sum(axis=1)
b_perm = b.permute(1, 0)
b.assign(r + b)