mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
raise TypeError calling len() on a 0-d tensor (#4970)
matched numpy and torch
This commit is contained in:
@@ -204,12 +204,14 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert Tensor.randn(1,1,1,1,1,1).numel() == 1
|
||||
assert Tensor([]).numel() == 0
|
||||
assert Tensor.randn(1,0,2,5).numel() == 0
|
||||
assert Tensor(3).numel() == 1
|
||||
|
||||
def test_len(self):
|
||||
assert len(torch.zeros(7)) == len(Tensor.zeros(7))
|
||||
assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20))
|
||||
assert len(torch.zeros(10,20)) == len(Tensor.zeros(10,20,30))
|
||||
assert len(torch.zeros(1).flatten()) == len(Tensor.zeros(1).flatten())
|
||||
with self.assertRaises(TypeError): len(Tensor(3))
|
||||
|
||||
def test_size(self):
|
||||
t1, t2 = torch.zeros(10,20), Tensor.zeros(10,20)
|
||||
|
||||
Reference in New Issue
Block a user