docs: split, chunk, pad2d, flatten, unflatten (#4706)

chenyu
2024-05-23 20:34:40 -04:00
committed by GitHub
parent 2c56aa7fe0
commit 8aee3f5a9a
2 changed files with 82 additions and 12 deletions


@@ -64,8 +64,6 @@
## Movement (high level)
::: tinygrad.Tensor.__getitem__
::: tinygrad.Tensor.slice
::: tinygrad.Tensor.gather
::: tinygrad.Tensor.cat
::: tinygrad.Tensor.stack


@@ -764,7 +764,7 @@ class Tensor:
Returns a tensor that pads each axis based on input `arg`.
`arg` has length `self.ndim`.
For each axis, the entry can be `None`, which means no padding, or a tuple `(pad_before, pad_after)`.
-If `value` is specified, the tensor is padded with `value`.
+If `value` is specified, the tensor is padded with `value` instead of `0.0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(6).reshape(2, 3)
@@ -925,8 +925,8 @@ class Tensor:
v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
assign_to.assign(v).realize()
-# NOTE: using slice is discouraged and things should migrate to pad and shrink
-def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
+# NOTE: using _slice is discouraged and things should migrate to pad and shrink
+def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
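`_slice` composes `pad` and `shrink`: bounds that fall outside the tensor become padding filled with `value` rather than an index error. A minimal sketch of that behavior (the example values are ours, not from the commit):

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
# bounds reach one row past the end and one column before the start,
# so the out-of-range region is filled with the pad value (default 0)
print(t._slice(((0, 3), (-1, 3))).numpy())
```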
@@ -1016,17 +1016,53 @@ class Tensor:
return dim + self.ndim+outer if dim < 0 else dim
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
"""
Splits the tensor into chunks along the dimension specified by `dim`.
If `sizes` is an integer, it splits into equally sized chunks where possible; otherwise the last chunk is smaller.
If `sizes` is a list, it splits into `len(sizes)` chunks whose sizes along `dim` are given by `sizes`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(10).reshape(5, 2)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
split = t.split(2)
print("\\n".join([repr(x.numpy()) for x in split]))
```
```python exec="true" source="above" session="tensor" result="python"
split = t.split([1, 4])
print("\\n".join([repr(x.numpy()) for x in split]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim = self._resolve_dim(dim)
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
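Both docstring examples split along the default `dim=0`; here is a quick sketch of the `dim` keyword, with values of our choosing:

```python
from tinygrad import Tensor

t = Tensor.arange(12).reshape(3, 4)
# split the second axis into pieces of width 3 and 1
a, b = t.split([3, 1], dim=1)
print(a.shape, b.shape)  # (3, 3) (3, 1)
```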
-def chunk(self, num:int, dim:int=0) -> List[Tensor]:
+def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
"""
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
The function may return fewer than the specified number of chunks.
```python exec="true" source="above" session="tensor" result="python"
chunked = Tensor.arange(11).chunk(6)
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
```python exec="true" source="above" session="tensor" result="python"
chunked = Tensor.arange(12).chunk(6)
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
```python exec="true" source="above" session="tensor" result="python"
chunked = Tensor.arange(13).chunk(6)
print("\\n".join([repr(x.numpy()) for x in chunked]))
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
-assert num > 0, f"expect num to be greater than 0, got: {num}"
+assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
dim = self._resolve_dim(dim)
-return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim))
+return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
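The "may return fewer" case from the docstring follows directly from the ceil division above; a small sketch with our own values:

```python
from tinygrad import Tensor

# ceil(4/6) = 1, so a length-4 tensor yields only 4 one-element chunks
chunked = Tensor.arange(4).chunk(6)
print(len(chunked))  # 4
```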
def squeeze(self, dim:Optional[int]=None) -> Tensor:
"""
@@ -1063,10 +1099,21 @@ class Tensor:
dim = self._resolve_dim(dim, outer=True)
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
# (padding_left, padding_right, padding_top, padding_bottom)
-def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
+def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
"""
Returns a tensor that pads the last two axes according to `padding` `(padding_left, padding_right, padding_top, padding_bottom)`.
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
```
"""
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
-return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
+return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
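`pad2d` only requires at least two axes, so it works on a plain matrix as well as the 4D example in the docstring; a sketch with our own values:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
# one column of zeros on the left and right, one row of zeros on top
print(t.pad2d((1, 1, 1, 0)).numpy())
```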
@property
def T(self) -> Tensor:
@@ -1088,10 +1135,35 @@ class Tensor:
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1):
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` is passed, only the dimensions from `start_dim` through `end_dim` are flattened.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(8).reshape(2, 2, 2)
print(t.flatten().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.flatten(start_dim=1).numpy())
```
"""
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
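The docstring exercises `start_dim` only; `end_dim` bounds the flattened range from the right. A sketch with our own values:

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4, 5)
# collapse only the middle two dimensions
print(t.flatten(start_dim=1, end_dim=2).shape)  # (2, 12, 5)
```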
def unflatten(self, dim:int, sizes:Tuple[int,...]):
"""
Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
```
"""
dim = self._resolve_dim(dim)
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
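`unflatten` undoes `flatten` when given the original sizes; a round-trip sketch with our own values:

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3, 4)
# flatten the trailing dims, then restore them
print(t.flatten(start_dim=1).unflatten(1, (3, 4)).shape)  # (2, 3, 4)
```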
@@ -2111,7 +2183,7 @@ class Tensor:
# padding
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
+x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
# prepare input
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
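The `padding_` expression above normalizes the three accepted forms to `(left, right, top, bottom)`; a standalone sketch of that expansion (the helper name is ours, not tinygrad API):

```python
def normalize_padding(padding):
    # int -> the same pad on all four sides; (py, px) -> (px, px, py, py)
    if isinstance(padding, int): return [padding] * 4
    return padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]]

print(normalize_padding(1))       # [1, 1, 1, 1]
print(normalize_padding((2, 3)))  # [3, 3, 2, 2]
```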