From 8aee3f5a9a06c8cfa4ad2b6045397af566f1b707 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 23 May 2024 20:34:40 -0400 Subject: [PATCH] docs: split, chunk, pad2d, flatten, unflatten (#4706) --- docs/tensor.md | 2 - tinygrad/tensor.py | 92 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 82 insertions(+), 12 deletions(-) diff --git a/docs/tensor.md b/docs/tensor.md index 208ee04662..f3d5305bd4 100644 --- a/docs/tensor.md +++ b/docs/tensor.md @@ -64,8 +64,6 @@ ## Movement (high level) -::: tinygrad.Tensor.__getitem__ -::: tinygrad.Tensor.slice ::: tinygrad.Tensor.gather ::: tinygrad.Tensor.cat ::: tinygrad.Tensor.stack diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1ade6031d1..ced826116a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -764,7 +764,7 @@ class Tensor: Returns a tensor that pads the each axis based on input arg. arg has the same length as `self.ndim`. For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`. - If `value` is specified, the tensor is padded with `value`. + If `value` is specified, the tensor is padded with `value` instead of `0.0`. 
```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) @@ -925,8 +925,8 @@ class Tensor: v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous() assign_to.assign(v).realize() - # NOTE: using slice is discouraged and things should migrate to pad and shrink - def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor: + # NOTE: using _slice is discouraged and things should migrate to pad and shrink + def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor: arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg)) padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_)) return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding))) @@ -1016,17 +1016,53 @@ class Tensor: return dim + self.ndim+outer if dim < 0 else dim def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]: + """ + Splits the tensor into chunks along the dimension specified by `dim`. + If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller. + If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `sizes`. 
+ + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(10).reshape(5, 2) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + split = t.split(2) + print("\\n".join([repr(x.numpy()) for x in split])) + ``` + ```python exec="true" source="above" session="tensor" result="python" + split = t.split([1, 4]) + print("\\n".join([repr(x.numpy()) for x in split])) + ``` + """ assert all_int(self.shape), f"does not support symbolic shape {self.shape}" dim = self._resolve_dim(dim) if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))] assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}" return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))]) - def chunk(self, num:int, dim:int=0) -> List[Tensor]: + def chunk(self, chunks:int, dim:int=0) -> List[Tensor]: + """ + Splits the tensor into `chunks` number of chunks along the dimension `dim`. + If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one. + The function may return fewer than the specified number of chunks. 
+ + ```python exec="true" source="above" session="tensor" result="python" + chunked = Tensor.arange(11).chunk(6) + print("\\n".join([repr(x.numpy()) for x in chunked])) + ``` + ```python exec="true" source="above" session="tensor" result="python" + chunked = Tensor.arange(12).chunk(6) + print("\\n".join([repr(x.numpy()) for x in chunked])) + ``` + ```python exec="true" source="above" session="tensor" result="python" + chunked = Tensor.arange(13).chunk(6) + print("\\n".join([repr(x.numpy()) for x in chunked])) + ``` + """ assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - assert num > 0, f"expect num to be greater than 0, got: {num}" + assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}" dim = self._resolve_dim(dim) - return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim)) + return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim)) def squeeze(self, dim:Optional[int]=None) -> Tensor: """ @@ -1063,10 +1099,21 @@ class Tensor: dim = self._resolve_dim(dim, outer=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) - # (padding_left, padding_right, padding_top, padding_bottom) - def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor: + def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor: + """ + Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom). + If `value` is specified, the tensor is padded with `value` instead of `0.0`. 
+ + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(1, 1, 3, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy()) + ``` + """ slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1] - return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value) + return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value) @property def T(self) -> Tensor: @@ -1088,10 +1135,35 @@ class Tensor: return self.permute(order) def flatten(self, start_dim=0, end_dim=-1): + """ + Flattens the tensor by reshaping it into a one-dimensional tensor. + If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(8).reshape(2, 2, 2) + print(t.flatten().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.flatten(start_dim=1).numpy()) + ``` + """ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim) return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) def unflatten(self, dim:int, sizes:Tuple[int,...]): + """ + Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`. 
+ + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape) + ``` + """ dim = self._resolve_dim(dim) return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) @@ -2111,7 +2183,7 @@ class Tensor: # padding padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]) - x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None)) + x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None)) # prepare input x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)