change Tensor.stack to method (#4719)

This commit is contained in:
chenyu
2024-05-24 17:04:19 -04:00
committed by GitHub
parent ba116ff630
commit 31358cbea5
15 changed files with 40 additions and 42 deletions

View File

@@ -131,13 +131,13 @@ class TestIndexing(unittest.TestCase):
# indexing with step
reference = consec((10, 10, 10))
numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack([reference[1], reference[3]], 0))
numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack([reference[1], reference[3], reference[5]], 0))
numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack([reference[1], reference[5]], 0))
numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack([reference[2:4, 1], reference[2:4, 3]], 1))
numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))
numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0))
numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0))
numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0))
numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1))
numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0))
numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0))
numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1))
lst = [list(range(i, i+10)) for i in range(0, 100, 10)]
tensor = Tensor(lst)

View File

@@ -393,7 +393,7 @@ class TestLinearizer(unittest.TestCase):
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
r = Tensor.stack(a, b)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
@@ -890,7 +890,7 @@ class TestHandCodedOpts(unittest.TestCase):
assert k.upcasted == 1 and k.full_shape[-1] == 7
def test_masked_upcast_wino(self):
monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = create_schedule([monster.lazydata])[-1]
k = Linearizer(*s.ast)

View File

@@ -1062,9 +1062,9 @@ class TestOps(unittest.TestCase):
lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:])
def test_stack_slice(self):
helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:])
helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack([x for i in range(3)])[0,0])
helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3])
helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack(*[x for i in range(3)])[0,:])
helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack(*[x for i in range(3)])[0,0])
helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack(*[x for i in range(4)])[3])
def test_transpose(self):
helper_test_op([(3,3)], lambda x: x.T)
@@ -1554,13 +1554,13 @@ class TestOps(unittest.TestCase):
def test_stack(self):
for dim in range(-1, 3):
helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack([x, y, z], dim))
helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack(x, y, z, dim=dim))
with self.assertRaises(IndexError):
Tensor.stack([Tensor.randn(45, 65, 3)], dim=77)
Tensor.stack(Tensor.randn(45, 65, 3), dim=77)
a = Tensor(3.14)
np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy())
def test_repeat(self):
x = Tensor.randn(4, 6, 3)