add Tensor.split (#2750)

* add Tensor.split (#2677)

* fix mypy errors

* add list support for Tensor.split

* fix ruff comments

* match tensor.split api

* simplify split and test_split

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Kevin Herro
2024-01-02 00:09:04 -06:00
committed by GitHub
parent e7a432b479
commit bd6a0c90a0
2 changed files with 22 additions and 0 deletions

View File

@@ -104,6 +104,22 @@ class TestOps(unittest.TestCase):
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
def test_split(self):
  # Each case is (torch reference tensor, tinygrad tensor, split spec). The spec is either a single
  # int piece size or an explicit list of piece sizes, so both call forms of split() get exercised,
  # including lengths that do not divide evenly (10 split by 3).
  cases = [
    (torch.arange(10), Tensor.arange(10), 5),
    (torch.arange(10), Tensor.arange(10), [1, 4, 5]),
    (torch.arange(10), Tensor.arange(10), 3),
    (torch.arange(12).reshape(3, 4), Tensor.arange(12).reshape(3, 4), 1),
    (torch.arange(16).reshape(4, 4), Tensor.arange(16).reshape(4, 4), [2, 2]),
    (torch.arange(10000), Tensor.arange(10000), 2500),
  ]
  for reference, candidate, spec in cases:
    expected_parts = reference.split(spec)
    actual_parts = candidate.split(spec)
    # Compare the piece count first: the zip below would silently drop unmatched trailing pieces.
    assert len(expected_parts) == len(actual_parts)
    for expected, actual in zip(expected_parts, actual_parts):
      # The lambdas are handed to helper_test_op inside this same iteration,
      # so closing over the loop variables is safe here.
      helper_test_op([], lambda: expected, lambda: actual, forward_only=True)
def test_chunk(self):
tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)

View File

@@ -444,6 +444,12 @@ class Tensor:
final_shape = [r*s for r,s in zip(repeats, base_shape)]
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  """
  Splits the tensor into consecutive pieces along `dim`, mirroring the `torch.Tensor.split` call form.

  If `sizes` is an int, every piece spans `sizes` entries along `dim`, except the last one, which is
  shorter when the length of `dim` is not a multiple of `sizes`.
  If `sizes` is a list, piece `i` spans `sizes[i]` entries along `dim`. The list is used as given;
  it is not checked against the length of `dim`.

  A negative `dim` counts from the last dimension. The pieces are returned in order, as a tuple.
  Symbolic shapes are not supported.
  """
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = dim + self.ndim if dim < 0 else dim
  if isinstance(sizes, int):
    assert sizes > 0, f"split size must be positive, got {sizes}"
    total = self.shape[dim]
    # Build the piece lengths directly instead of delegating to chunk(): chunk() takes a piece *count* and
    # derives its own step from it, so 10 entries split by 6 would come back as 5+5 instead of 6+4, and the
    # call would also have to forward `dim`. max(1, total) keeps one empty piece for a zero-length dim.
    lengths = [min(sizes, total - start) for start in range(0, max(1, total), sizes)]
  else:
    lengths = list(sizes)
  pieces, offset = [], 0
  for length in lengths:
    # Full slices for every axis before `dim`, then the window for this piece; trailing axes are implicit.
    pieces.append(self[tuple([slice(None)]*dim + [slice(offset, offset + length)])])
    offset += length
  return tuple(pieces)
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)