diff --git a/test/test_ops.py b/test/test_ops.py
index 1d4c8f67a2..10fbd9dc6d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -98,6 +98,31 @@ class TestOps(unittest.TestCase):
     helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
     helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
 
+  def test_chunk(self):
+    tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
+    ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
+    assert len(tor) == len(ten)
+    for i in range(len(tor)):
+      helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
+
+    tor = torch.arange(13).repeat(8, 1).chunk(6, 0)
+    ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 0)
+    assert len(tor) == len(ten)
+    for i in range(len(tor)):
+      helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
+
+    tor = torch.arange(13).repeat(8, 1).chunk(3, -1)
+    ten = Tensor.arange(13).repeat((8, 1)).chunk(3, -1)
+    assert len(tor) == len(ten)
+    for i in range(len(tor)):
+      helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
+
+    tor = torch.arange(13).repeat(8, 3, 3).chunk(3, -2)
+    ten = Tensor.arange(13).repeat((8, 3, 3)).chunk(3, -2)
+    assert len(tor) == len(ten)
+    for i in range(len(tor)):
+      helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True)
+
   def test_arange(self):
     helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True)
     helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(10, 5, 3), forward_only=True)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index fb48361e07..544b3652b7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -340,10 +340,10 @@ class Tensor:
 
   # TODO: make this nicer with syntactic sugar in slice
   def chunk(self, num, dim):
-    slice_params = [[(0, s) for s in self.shape] for _ in range(num)]
-    for i,k in enumerate(range(0, self.shape[dim], self.shape[dim]//num)):
-      slice_params[i][dim] = (k, min(self.shape[dim], k+self.shape[dim]//num))
-    return [self.slice(p) for p in slice_params]
+    slice_params = [[slice(None) for s in self.shape] for _ in range(ceil(self.shape[dim]/ceil(self.shape[dim]/num)))]
+    for i, k in enumerate(range(0, self.shape[dim], ceil(self.shape[dim]/num))):
+      slice_params[i][dim] = slice(k, k + ceil(self.shape[dim]/num))
+    return [self[tuple(sl)] for sl in slice_params]
 
   def squeeze(self, dim=None):
     if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1])