|
|
|
|
@@ -1033,16 +1033,17 @@ class Tensor(SimpleMathTrait):
|
|
|
|
|
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Returns a tensor with padding applied based on the input `padding`.
|
|
|
|
|
|
|
|
|
|
`padding` supports two padding structures:
|
|
|
|
|
|
|
|
|
|
1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...)
|
|
|
|
|
- This structure matches PyTorch's pad.
|
|
|
|
|
- `padding` length must be even.
|
|
|
|
|
1. Flat padding: `(padding_left, padding_right, padding_top, padding_bottom, ...)`
|
|
|
|
|
- This structure matches PyTorch's pad.
|
|
|
|
|
- `padding` length must be even.
|
|
|
|
|
|
|
|
|
|
2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right))
|
|
|
|
|
- This structure matches pad for jax, numpy, tensorflow and others.
|
|
|
|
|
- For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
|
|
|
|
|
- `padding` must have the same length as `self.ndim`.
|
|
|
|
|
2. Group padding: `(..., (padding_top, padding_bottom), (padding_left, padding_right))`
|
|
|
|
|
- This structure matches pad for JAX, NumPy, TensorFlow, and others.
|
|
|
|
|
- For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`.
|
|
|
|
|
- `padding` must have the same length as `self.ndim`.
|
|
|
|
|
|
|
|
|
|
Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices.
|
|
|
|
|
Padding modes is selected with `mode` which supports `constant`, `reflect` and `replicate`.
|
|
|
|
|
@@ -1090,35 +1091,6 @@ class Tensor(SimpleMathTrait):
|
|
|
|
|
return X.shrink(tuple((-min(pB,0), min(pA+s,s)) for (pB,pA),s in zip(pX, X.shape)))
|
|
|
|
|
|
|
|
|
|
# ***** movement high level ops *****
|
|
|
|
|
|
|
|
|
|
# Supported Indexing Implementations:
|
|
|
|
|
# 1. Int indexing (no copy)
|
|
|
|
|
# - for all dims where there's int, shrink -> reshape
|
|
|
|
|
# - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
|
|
|
|
|
# - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9)
|
|
|
|
|
# - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,)
|
|
|
|
|
# 2. Slice indexing (no copy)
|
|
|
|
|
# - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink
|
|
|
|
|
# - first shrink the Tensor to X.shrink(((start, end),))
|
|
|
|
|
# - then we apply stride through Optional[flip] -> pad -> reshape -> shrink
|
|
|
|
|
# - flip where dim value is negative
|
|
|
|
|
# - pad on dims to be multiple of strides, such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible
|
|
|
|
|
# - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1]
|
|
|
|
|
# - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride
|
|
|
|
|
# 3. None indexing (no copy)
|
|
|
|
|
# - reshape (inject) a dim at the dim where there's None
|
|
|
|
|
# 4. Tensor indexing (copy)
|
|
|
|
|
# - use Tensor.arange == tensor_index to create masks for dims with Tensors (adds a dim for each mask)
|
|
|
|
|
# - combine masks together with mul
|
|
|
|
|
# - apply mask to self by mask * self
|
|
|
|
|
# - sum reduce away the extra dims added from creating masks
|
|
|
|
|
# Tiny Things:
|
|
|
|
|
# 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis]
|
|
|
|
|
# - for any list, list[Union[List, Tuple, int]], must have homogeneous shape
|
|
|
|
|
# - for any tuple, tuple[Union[List, Tuple, int]], must have homogeneous shape
|
|
|
|
|
# 2. Bool indexing is not supported
|
|
|
|
|
# 3. Out of bounds Tensor indexing results in 0
|
|
|
|
|
# - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are out of bounds
|
|
|
|
|
def _getitem(self, indices, v: Optional[Tensor] = None) -> Tensor:
|
|
|
|
|
# wrap single index into a list
|
|
|
|
|
if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)): indices = [indices]
|
|
|
|
|
@@ -1210,6 +1182,43 @@ class Tensor(SimpleMathTrait):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, indices) -> Tensor:
|
|
|
|
|
"""
|
|
|
|
|
Retrieve a sub-tensor using indexing.
|
|
|
|
|
|
|
|
|
|
Supported Index Types: `int | slice | Tensor | None | List | Tuple | Ellipsis`
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
t = Tensor.arange(12).reshape(3, 4)
|
|
|
|
|
print(t.numpy())
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
- Int Indexing: Select an element or sub-tensor using integers for each dimension.
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
print(t[1, 2].numpy())
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
- Slice Indexing: Select a range of elements using slice notation (`start:end:stride`).
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
print(t[0:2, ::2].numpy())
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
- Tensor Indexing: Use another tensor as indices for advanced indexing. Using `tuple` or `list` here also works.
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
print(t[Tensor([2, 0, 1]), Tensor([1, 2, 3])].numpy())
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
- `None` Indexing: Add a new dimension to the tensor.
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
print(t[:, None].shape)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
NOTE: Out-of-bounds indexing results in a value of `0`.
|
|
|
|
|
```python exec="true" source="above" session="tensor" result="python"
|
|
|
|
|
t = Tensor([1, 2, 3])
|
|
|
|
|
print(t[Tensor([4, 3, 2])].numpy())
|
|
|
|
|
```
|
|
|
|
|
"""
|
|
|
|
|
return self._getitem(indices)
|
|
|
|
|
|
|
|
|
|
def __setitem__(self, indices, v:Union[Tensor, ConstType]) -> None:
|
|
|
|
|
|