mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -378,17 +378,17 @@ class TestMoveTensor(unittest.TestCase):
|
||||
|
||||
class TestZeroShapeTensor(unittest.TestCase):
|
||||
def test_shape_stride(self):
|
||||
t = Tensor.empty(3, 2, 0)
|
||||
t = Tensor.rand(3, 2, 0)
|
||||
assert t.shape == (3, 2, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 1)
|
||||
|
||||
t = Tensor.empty(3, 0, 2)
|
||||
t = Tensor.rand(3, 0, 2)
|
||||
assert t.shape == (3, 0, 2)
|
||||
# numpy has stride 0, 0, 0; torch has stride 2, 2, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 2, 1)
|
||||
|
||||
t = Tensor.empty(0, 0, 0)
|
||||
t = Tensor.rand(0, 0, 0)
|
||||
assert t.shape == (0, 0, 0)
|
||||
# numpy has stride 0, 0, 0; torch has stride 1, 1, 1
|
||||
assert t.lazydata.st.real_strides() == (0, 0, 1)
|
||||
|
||||
Reference in New Issue
Block a user