mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
assert error detail in test_assign (#4567)
* use regex assert * that shouldnt raise
This commit is contained in:
@@ -29,7 +29,6 @@ class TestAssign(unittest.TestCase):
|
||||
def test_assign_zeros(self):
|
||||
a = Tensor.zeros(10,10).contiguous()
|
||||
b = Tensor.zeros(10,10).contiguous()
|
||||
#with self.assertRaises(RuntimeError):
|
||||
a.assign(Tensor.ones(10,10))
|
||||
a.realize()
|
||||
np.testing.assert_allclose(b.numpy(), 0)
|
||||
@@ -117,7 +116,7 @@ class TestAssign(unittest.TestCase):
|
||||
|
||||
def test_assign_diamond_cycle(self):
|
||||
# NOTE: should *not* raise AssertionError from numpy
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
times_a = a*3
|
||||
a.assign(Tensor.full((4,), 2.).contiguous())
|
||||
@@ -125,7 +124,7 @@ class TestAssign(unittest.TestCase):
|
||||
np.testing.assert_allclose(new.numpy(), 4)
|
||||
|
||||
def test_assign_diamond_contiguous_cycle(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
a = Tensor.ones(4).contiguous().realize()
|
||||
times_a = a*3
|
||||
a.assign(Tensor.full((4,), 2.))
|
||||
@@ -202,7 +201,7 @@ class TestAssign(unittest.TestCase):
|
||||
|
||||
def test_crossunder_assign(self):
|
||||
# NOTE: should *not* raise AssertionError from numpy
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
a = Tensor.full((4,), 2).contiguous().realize()
|
||||
b = Tensor.full((4,), 3).contiguous().realize()
|
||||
c = a+9
|
||||
@@ -273,7 +272,7 @@ class TestAssign(unittest.TestCase):
|
||||
#GlobalCounters.cache = []
|
||||
ba1 = a.lazydata.base.realized # noqa: F841
|
||||
bb1 = b.lazydata.base.realized # noqa: F841
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "contiguous"):
|
||||
a.assign(a.permute(1,0) + b) # this should not work!
|
||||
a.realize()
|
||||
ba2 = a.lazydata.base.realized # noqa: F841
|
||||
@@ -305,7 +304,7 @@ class TestAssign(unittest.TestCase):
|
||||
a = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
|
||||
b = Tensor.arange(4 * 4).reshape(4, 4).contiguous().realize()
|
||||
# TODO: scheduler limitation, should NOT raise AssertionError from numpy.
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "contiguous"):
|
||||
a = a.permute(1, 0)
|
||||
new_val = a + b
|
||||
a.assign(new_val)
|
||||
@@ -314,7 +313,7 @@ class TestAssign(unittest.TestCase):
|
||||
def test_permuted_reduceop_child_dual_use(self):
|
||||
a = Tensor.randn(32, 32, 32).realize()
|
||||
b = Tensor.full((32, 32), 1.).contiguous().realize()
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "contiguous"):
|
||||
r = a.sum(axis=1)
|
||||
b.assign(r + b.permute(1, 0))
|
||||
b.realize()
|
||||
@@ -325,7 +324,7 @@ class TestAssign(unittest.TestCase):
|
||||
c = Tensor.full((32, 32), 2.).contiguous().realize()
|
||||
|
||||
# TODO: this is failing in cycle error, it should fail earlier.
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(RuntimeError, "cycle"):
|
||||
r = a.sum(axis=1)
|
||||
b_perm = b.permute(1, 0)
|
||||
b.assign(r + b)
|
||||
|
||||
Reference in New Issue
Block a user