mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
docs: split, chunk, pad2d, flatten, unflatten (#4706)
This commit is contained in:
@@ -64,8 +64,6 @@
|
||||
|
||||
## Movement (high level)
|
||||
|
||||
::: tinygrad.Tensor.__getitem__
|
||||
::: tinygrad.Tensor.slice
|
||||
::: tinygrad.Tensor.gather
|
||||
::: tinygrad.Tensor.cat
|
||||
::: tinygrad.Tensor.stack
|
||||
|
||||
@@ -764,7 +764,7 @@ class Tensor:
|
||||
Returns a tensor that pads the each axis based on input arg.
|
||||
arg has the same length as `self.ndim`.
|
||||
For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`.
|
||||
If `value` is specified, the tensor is padded with `value`.
|
||||
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6).reshape(2, 3)
|
||||
@@ -925,8 +925,8 @@ class Tensor:
|
||||
v = v.cast(assign_to.dtype)._broadcast_to(_broadcast_shape(assign_to.shape, v.shape)).contiguous()
|
||||
assign_to.assign(v).realize()
|
||||
|
||||
# NOTE: using slice is discouraged and things should migrate to pad and shrink
|
||||
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
|
||||
# NOTE: using _slice is discouraged and things should migrate to pad and shrink
|
||||
def _slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
|
||||
arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg))
|
||||
padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_))
|
||||
return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding)))
|
||||
@@ -1016,17 +1016,53 @@ class Tensor:
|
||||
return dim + self.ndim+outer if dim < 0 else dim
|
||||
|
||||
def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]:
|
||||
"""
|
||||
Splits the tensor into chunks along the dimension specified by `dim`.
|
||||
If `sizes` is an integer, it splits into equally sized chunks if possible, otherwise the last chunk will be smaller.
|
||||
If `sizes` is a list, it splits into `len(sizes)` chunks with size in `dim` according to `size`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(10).reshape(5, 2)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
split = t.split(2)
|
||||
print("\\n".join([repr(x.numpy()) for x in split]))
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
split = t.split([1, 4])
|
||||
print("\\n".join([repr(x.numpy()) for x in split]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
dim = self._resolve_dim(dim)
|
||||
if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in range(0, max(1, self.shape[dim]), max(1, sizes))]
|
||||
assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}"
|
||||
return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))])
|
||||
|
||||
def chunk(self, num:int, dim:int=0) -> List[Tensor]:
|
||||
def chunk(self, chunks:int, dim:int=0) -> List[Tensor]:
|
||||
"""
|
||||
Splits the tensor into `chunks` number of chunks along the dimension `dim`.
|
||||
If the tensor size along `dim` is not divisible by `chunks`, all returned chunks will be the same size except the last one.
|
||||
The function may return fewer than the specified number of chunks.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
chunked = Tensor.arange(11).chunk(6)
|
||||
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
chunked = Tensor.arange(12).chunk(6)
|
||||
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
chunked = Tensor.arange(13).chunk(6)
|
||||
print("\\n".join([repr(x.numpy()) for x in chunked]))
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
|
||||
assert num > 0, f"expect num to be greater than 0, got: {num}"
|
||||
assert chunks > 0, f"expect chunks to be greater than 0, got: {chunks}"
|
||||
dim = self._resolve_dim(dim)
|
||||
return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim))
|
||||
return list(self.split(math.ceil(self.shape[dim]/chunks) if self.shape[dim] else [0]*chunks, dim=dim))
|
||||
|
||||
def squeeze(self, dim:Optional[int]=None) -> Tensor:
|
||||
"""
|
||||
@@ -1063,10 +1099,21 @@ class Tensor:
|
||||
dim = self._resolve_dim(dim, outer=True)
|
||||
return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
|
||||
|
||||
# (padding_left, padding_right, padding_top, padding_bottom)
|
||||
def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor:
|
||||
def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor:
|
||||
"""
|
||||
Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom).
|
||||
If `value` is specified, the tensor is padded with `value` instead of `0.0`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(9).reshape(1, 1, 3, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy())
|
||||
```
|
||||
"""
|
||||
slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1]
|
||||
return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
return self._slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value)
|
||||
|
||||
@property
|
||||
def T(self) -> Tensor:
|
||||
@@ -1088,10 +1135,35 @@ class Tensor:
|
||||
return self.permute(order)
|
||||
|
||||
def flatten(self, start_dim=0, end_dim=-1):
|
||||
"""
|
||||
Flattens the tensor by reshaping it into a one-dimensional tensor.
|
||||
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(8).reshape(2, 2, 2)
|
||||
print(t.flatten().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.flatten(start_dim=1).numpy())
|
||||
```
|
||||
"""
|
||||
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
|
||||
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
|
||||
|
||||
def unflatten(self, dim:int, sizes:Tuple[int,...]):
|
||||
"""
|
||||
Expands dimension `dim` of the tensor over multiple dimensions specified by `sizes`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
|
||||
```
|
||||
"""
|
||||
dim = self._resolve_dim(dim)
|
||||
return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
|
||||
|
||||
@@ -2111,7 +2183,7 @@ class Tensor:
|
||||
|
||||
# padding
|
||||
padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
|
||||
x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
||||
x = x._slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
|
||||
|
||||
# prepare input
|
||||
x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
|
||||
|
||||
Reference in New Issue
Block a user