mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add Tensor.split (#2750)
* add Tensor.split (#2677) * fix mypy errors * add list support for Tensor.split * fix ruff comments * match tensor.split api * simplify split and test_split --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -104,6 +104,22 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True)
|
||||
|
||||
def test_split(self):
  # Each case is (torch tensor, tinygrad tensor, split spec). Covers int
  # specs and explicit size lists, 1-D and 2-D inputs, and one large tensor.
  cases = [
    (torch.arange(10), Tensor.arange(10), 5),
    (torch.arange(10), Tensor.arange(10), [1, 4, 5]),
    (torch.arange(10), Tensor.arange(10), 3),
    (torch.arange(12).reshape(3, 4), Tensor.arange(12).reshape(3, 4), 1),
    (torch.arange(16).reshape(4, 4), Tensor.arange(16).reshape(4, 4), [2, 2]),
    (torch.arange(10000), Tensor.arange(10000), 2500),
  ]
  for torch_t, tiny_t, sizes in cases:
    torch_parts = torch_t.split(sizes)
    tiny_parts = tiny_t.split(sizes)
    # Both frameworks must agree on the number of chunks before comparing them.
    assert len(torch_parts) == len(tiny_parts)
    for torch_part, tiny_part in zip(torch_parts, tiny_parts):
      helper_test_op([], lambda: torch_part, lambda: tiny_part, forward_only=True)
|
||||
def test_chunk(self):
|
||||
tor = torch.arange(13).repeat(8, 1).chunk(6, 1)
|
||||
ten = Tensor.arange(13).repeat((8, 1)).chunk(6, 1)
|
||||
|
||||
@@ -444,6 +444,12 @@ class Tensor:
|
||||
final_shape = [r*s for r,s in zip(repeats, base_shape)]
|
||||
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
|
||||
|
||||
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
  # Split self into sub-tensors along `dim`, mirroring torch.split:
  # if `sizes` is an int, every chunk has that length along `dim` (the last
  # may be smaller); if it is a list, chunk i has length sizes[i].
  # Returns a tuple of views; raises AssertionError on symbolic shapes or
  # when a size list does not sum to the dimension length.
  assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
  dim = dim + self.ndim if dim < 0 else dim  # normalize negative dim
  if isinstance(sizes, int):
    # Expand the int spec into an explicit size list so a single slicing path
    # handles both forms. The previous `self.chunk(ceil(shape[dim]/sizes))`
    # dropped `dim` (chunk defaults to dim=0) and could produce chunk sizes
    # other than `sizes` (e.g. length 9, sizes=4 gave 3,3,3 instead of 4,4,1).
    # max(1, ...) keeps this well-defined for a zero-length dimension,
    # yielding a single empty chunk like torch does.
    sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
  assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
  return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
||||
|
||||
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
|
||||
|
||||
Reference in New Issue
Block a user