comment about multi real and more tests [pr] (#7467)

This commit is contained in:
chenyu
2024-11-01 11:49:11 -04:00
committed by GitHub
parent 1f343aa40e
commit 18e159c9ac
2 changed files with 7 additions and 2 deletions

View File

@@ -772,6 +772,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
a = t.shrink(((2, 4), None))
b = t.shrink(((6, 8), None))
self.assertEqual(a.lazydata.real, [False, True, False, False])
self.assertEqual(b.lazydata.real, [False, False, False, True])
na = t.numpy()[2:4]
nb = t.numpy()[6:8]
np.testing.assert_equal(a.numpy(), na)
@@ -781,6 +783,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
c = a + b
c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
self.assertEqual(c.lazydata.real, [True, True, True, True])
expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
np.testing.assert_equal(c.numpy(), expected)