raise TypeError calling len() on a 0-d tensor (#4970)

matched numpy and torch
This commit is contained in:
chenyu
2024-06-14 16:34:27 -04:00
committed by GitHub
parent 67e8df4969
commit 64cda3c481
2 changed files with 5 additions and 1 deletions

View File

@@ -204,12 +204,14 @@ class TestTinygrad(unittest.TestCase):
assert Tensor.randn(1,1,1,1,1,1).numel() == 1
assert Tensor([]).numel() == 0
assert Tensor.randn(1,0,2,5).numel() == 0
assert Tensor(3).numel() == 1
def test_len(self):
assert len(torch.zeros(7)) == len(Tensor.zeros(7))
assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20))
assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20,30))
assert len(torch.zeros(1).flatten()) == len(Tensor.zeros(1).flatten())
with self.assertRaises(TypeError): len(Tensor(3))
def test_size(self):
t1, t2 = torch.zeros(10,20), Tensor.zeros(10,20)