mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
Added test_chunk and fixed (#1283)
This commit is contained in:
@@ -98,6 +98,31 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
|
||||
|
||||
def test_chunk(self):
|
||||
tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
|
||||
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
|
||||
assert len(tor) == len(ten)
|
||||
for i in range(len(tor)):
|
||||
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
|
||||
|
||||
tor = torch.arange(13).repeat(8, 1).chunk(6, 0)
|
||||
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0)
|
||||
assert len(tor) == len(ten)
|
||||
for i in range(len(tor)):
|
||||
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
|
||||
|
||||
tor = torch.arange(13).repeat(8, 1).chunk(3, -1)
|
||||
ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1)
|
||||
assert len(tor) == len(ten)
|
||||
for i in range(len(tor)):
|
||||
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
|
||||
|
||||
tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2)
|
||||
ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2)
|
||||
assert len(tor) == len(ten)
|
||||
for i in range(len(tor)):
|
||||
helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
|
||||
|
||||
def test_arange(self):
|
||||
helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(10, 5, 3), forward_only=True)
|
||||
|
||||
Reference in New Issue
Block a user