mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
comment about multi real and more tests [pr] (#7467)
This commit is contained in:
@@ -772,6 +772,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
|
||||
a = t.shrink(((2, 4), None))
|
||||
b = t.shrink(((6, 8), None))
|
||||
self.assertEqual(a.lazydata.real, [False, True, False, False])
|
||||
self.assertEqual(b.lazydata.real, [False, False, False, True])
|
||||
na = t.numpy()[2:4]
|
||||
nb = t.numpy()[6:8]
|
||||
np.testing.assert_equal(a.numpy(), na)
|
||||
@@ -781,6 +783,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
|
||||
c = a + b
|
||||
|
||||
c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
|
||||
self.assertEqual(c.lazydata.real, [True, True, True, True])
|
||||
expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
|
||||
np.testing.assert_equal(c.numpy(), expected)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user