mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
docs: example formatting, multi examples, activation inputs (#4709)
This commit is contained in:
@@ -34,6 +34,9 @@
|
||||
::: tinygrad.Tensor.full
|
||||
::: tinygrad.Tensor.arange
|
||||
::: tinygrad.Tensor.eye
|
||||
::: tinygrad.Tensor.full_like
|
||||
::: tinygrad.Tensor.zeros_like
|
||||
::: tinygrad.Tensor.ones_like
|
||||
|
||||
## Creation (random)
|
||||
|
||||
|
||||
@@ -74,7 +74,9 @@ class Tensor:
|
||||
A `Tensor` is a multi-dimensional matrix containing elements of a single data type.
|
||||
|
||||
```python exec="true" session="tensor"
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, dtypes
|
||||
import numpy as np
|
||||
np.set_printoptions(precision=5)
|
||||
```
|
||||
"""
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
@@ -361,8 +363,10 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.full((2, 3), 42)
|
||||
print(t.numpy())
|
||||
print(Tensor.full((2, 3), 42).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.full((2, 3), False).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
@@ -376,8 +380,10 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.zeros(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.zeros(2, 3).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(argfix(*shape), 0.0, **kwargs)
|
||||
@@ -391,8 +397,10 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.ones(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.ones(2, 3).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(argfix(*shape), 1.0, **kwargs)
|
||||
@@ -408,18 +416,16 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5)
|
||||
print(t.numpy())
|
||||
print(Tensor.arange(5).numpy())
|
||||
```
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5, 10)
|
||||
print(t.numpy())
|
||||
print(Tensor.arange(5, 10).numpy())
|
||||
```
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(5, 10, 2)
|
||||
print(t.numpy())
|
||||
print(Tensor.arange(5, 10, 2).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.arange(5.5, 10, 2).numpy())
|
||||
```
|
||||
"""
|
||||
if stop is None: stop, start = start, 0
|
||||
@@ -436,8 +442,7 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.eye(3)
|
||||
print(t.numpy())
|
||||
print(Tensor.eye(3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim)
|
||||
@@ -451,12 +456,12 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.ones(2, 3)
|
||||
t = Tensor.full_like(ot, 42)
|
||||
print(t.numpy())
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.full_like(t, 42).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
||||
|
||||
def zeros_like(self, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the same shape as `tensor`, filled with zeros.
|
||||
@@ -465,12 +470,12 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.ones(2, 3)
|
||||
t = Tensor.zeros_like(ot)
|
||||
print(t.numpy())
|
||||
t = Tensor.ones(2, 3)
|
||||
print(Tensor.zeros_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
return self.full_like(0, **kwargs)
|
||||
|
||||
def ones_like(self, **kwargs):
|
||||
"""
|
||||
Creates a tensor with the same shape as `tensor`, filled with ones.
|
||||
@@ -479,9 +484,8 @@ class Tensor:
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
ot = Tensor.zeros(2, 3)
|
||||
t = Tensor.ones_like(ot)
|
||||
print(t.numpy())
|
||||
t = Tensor.zeros(2, 3)
|
||||
print(Tensor.ones_like(t).numpy())
|
||||
```
|
||||
"""
|
||||
return self.full_like(1, **kwargs)
|
||||
@@ -499,8 +503,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.randn(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
|
||||
@@ -518,8 +521,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy())
|
||||
print(Tensor.randint(2, 3, low=5, high=10).numpy())
|
||||
"""
|
||||
assert dtypes.is_int(dtype := kwargs.pop("dtype", dtypes.int32)), f"Unsupported dtype {dtype} for randint"
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
@@ -534,8 +536,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.normal(2, 3, mean=10, std=2)
|
||||
print(t.numpy())
|
||||
print(Tensor.normal(2, 3, mean=10, std=2).numpy())
|
||||
```
|
||||
"""
|
||||
return (std * Tensor.randn(*shape, **kwargs)) + mean
|
||||
@@ -550,8 +551,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.uniform(2, 3, low=2, high=10)
|
||||
print(t.numpy())
|
||||
print(Tensor.uniform(2, 3, low=2, high=10).numpy())
|
||||
```
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", dtypes.default_float)
|
||||
@@ -568,8 +568,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.scaled_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.scaled_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5)
|
||||
@@ -585,8 +584,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.glorot_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.glorot_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5)
|
||||
@@ -602,8 +600,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.kaiming_uniform(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.kaiming_uniform(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
@@ -620,8 +617,7 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.kaiming_normal(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor.kaiming_normal(2, 3).numpy())
|
||||
```
|
||||
"""
|
||||
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:]))
|
||||
@@ -653,9 +649,8 @@ class Tensor:
|
||||
Must be used on a scalar tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6.0, requires_grad=True)
|
||||
t2 = t.sum()
|
||||
t2.backward()
|
||||
t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
|
||||
t.sum().backward()
|
||||
print(t.grad.numpy())
|
||||
```
|
||||
"""
|
||||
@@ -719,7 +714,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6).reshape(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.permute(1, 0).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -732,7 +729,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6).reshape(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.flip(0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
@@ -749,7 +748,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(9).reshape(3, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.shrink(((None, (1, 3)))).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
@@ -768,7 +769,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6).reshape(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.pad(((None, (1, 2)))).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
@@ -937,7 +940,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([[1, 2], [3, 4]])
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -996,7 +1001,6 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([1, 2, 3])
|
||||
print(t.numpy(), "->")
|
||||
print(t.repeat(4, 2).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
@@ -1119,6 +1123,7 @@ class Tensor:
|
||||
def T(self) -> Tensor:
|
||||
"""`.T` is an alias for `.transpose(1, 0)`."""
|
||||
return self.transpose()
|
||||
|
||||
def transpose(self, dim0=1, dim1=0) -> Tensor:
|
||||
"""
|
||||
Returns a tensor that is a transposed version of the original tensor.
|
||||
@@ -1126,7 +1131,9 @@ class Tensor:
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor.arange(6).reshape(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.transpose(0, 1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1191,9 +1198,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.sum().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.sum(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.sum(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1209,9 +1222,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.max().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.max(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.max(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1226,9 +1245,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.min().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.min(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.min(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1244,15 +1269,22 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.mean().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.mean(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.mean(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
|
||||
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
|
||||
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype)
|
||||
|
||||
def var(self, axis=None, keepdim=False, correction=1):
|
||||
"""
|
||||
Returns the variance of the tensor along the specified axis or axes.
|
||||
@@ -1263,15 +1295,22 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.var().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.var(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.var(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
assert all_int(self.shape), "does not support symbolic shape"
|
||||
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
|
||||
return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction))
|
||||
|
||||
def std(self, axis=None, keepdim=False, correction=1):
|
||||
"""
|
||||
Returns the standard deviation of the tensor along the specified axis or axes.
|
||||
@@ -1282,9 +1321,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randint(2, 3, low=5, high=10)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.std().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.std(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.std(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1306,8 +1351,12 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.softmax().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.softmax(axis=0).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1325,8 +1374,12 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.log_softmax().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.log_softmax(axis=0).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1345,9 +1398,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logsumexp().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logsumexp(axis=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.logsumexp(axis=1).numpy())
|
||||
```
|
||||
"""
|
||||
@@ -1364,9 +1423,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor.
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0.
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1.
|
||||
```
|
||||
"""
|
||||
@@ -1377,6 +1442,7 @@ class Tensor:
|
||||
m = self == self.max(axis=axis, keepdim=True)
|
||||
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
|
||||
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32)
|
||||
|
||||
def argmin(self, axis=None, keepdim=False):
|
||||
"""
|
||||
Returns the indices of the minimum value of the tensor along the specified axis.
|
||||
@@ -1387,9 +1453,15 @@ class Tensor:
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy(), "->")
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor.
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0.
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
|
||||
```
|
||||
"""
|
||||
@@ -1564,12 +1636,7 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/relu
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.relu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Relu.apply(self)
|
||||
@@ -1580,12 +1647,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Sigmoid_function
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.sigmoid().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy())
|
||||
```
|
||||
"""
|
||||
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
|
||||
@@ -1619,12 +1681,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1511.07289v5
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.elu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy())
|
||||
```
|
||||
"""
|
||||
return self.relu() - alpha*(1-self.exp()).relu()
|
||||
@@ -1636,12 +1693,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1704.07483
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.celu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy())
|
||||
```
|
||||
"""
|
||||
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
|
||||
@@ -1652,12 +1704,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1710.05941v1
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.swish().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy())
|
||||
```
|
||||
"""
|
||||
return self * self.sigmoid()
|
||||
@@ -1669,12 +1716,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1606.08415
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.silu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy())
|
||||
```
|
||||
"""
|
||||
return self.swish() # The SiLU function is also known as the swish function.
|
||||
@@ -1686,12 +1728,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1704.04861v1
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.relu6().numpy())
|
||||
print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy())
|
||||
```
|
||||
"""
|
||||
return self.relu() - (self-6).relu()
|
||||
@@ -1703,12 +1740,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1905.02244v5
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.hardswish().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy())
|
||||
```
|
||||
"""
|
||||
return self * (self+3).relu6() * (1/6)
|
||||
@@ -1719,12 +1751,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.tanh().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy())
|
||||
```
|
||||
"""
|
||||
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
|
||||
@@ -1735,12 +1762,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.sinh().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy())
|
||||
```
|
||||
"""
|
||||
return (self.exp() - self.neg().exp()) / 2
|
||||
@@ -1751,12 +1773,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.cosh().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy())
|
||||
```
|
||||
"""
|
||||
return (self.exp() + self.neg().exp()) / 2
|
||||
@@ -1767,12 +1784,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = (Tensor.rand(2, 3) + 1) / 2
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.atanh().numpy())
|
||||
print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy())
|
||||
```
|
||||
"""
|
||||
return ((1 + self)/(1 - self)).log() / 2
|
||||
@@ -1783,12 +1795,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = (Tensor.rand(2, 3) + 1) / 2
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.asinh().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy())
|
||||
```
|
||||
"""
|
||||
return (self + (self.square() + 1).sqrt()).log()
|
||||
@@ -1799,12 +1806,7 @@ class Tensor:
|
||||
- Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.rand(2, 3) + 1
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.acosh().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy())
|
||||
```
|
||||
"""
|
||||
return (self + (self.square() - 1).sqrt()).log()
|
||||
@@ -1815,12 +1817,7 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/hardtanh-activation
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.hardtanh().numpy())
|
||||
print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy())
|
||||
```
|
||||
"""
|
||||
return self.clip(min_val, max_val)
|
||||
@@ -1832,12 +1829,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1606.08415v5
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.gelu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy())
|
||||
```
|
||||
"""
|
||||
return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh())
|
||||
@@ -1848,12 +1840,7 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/gelu
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.quick_gelu().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy())
|
||||
```
|
||||
"""
|
||||
return self * (self * 1.702).sigmoid()
|
||||
@@ -1864,15 +1851,10 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/leaky-relu
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.leakyrelu().numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.leakyrelu(neg_slope=1).numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy())
|
||||
```
|
||||
"""
|
||||
return self.relu() - (-neg_slope*self).relu()
|
||||
@@ -1884,12 +1866,7 @@ class Tensor:
|
||||
- Paper: https://arxiv.org/abs/1908.08681v3
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.mish().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy())
|
||||
```
|
||||
"""
|
||||
return self * self.softplus().tanh()
|
||||
@@ -1900,12 +1877,7 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/softplus
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.softplus().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy())
|
||||
```
|
||||
"""
|
||||
return (1/beta) * (1 + (self*beta).exp()).log()
|
||||
@@ -1916,12 +1888,7 @@ class Tensor:
|
||||
- Described: https://paperswithcode.com/method/softsign
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
Tensor.manual_seed(42)
|
||||
t = Tensor.randn(2, 3)
|
||||
print(t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t.softsign().numpy())
|
||||
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy())
|
||||
```
|
||||
"""
|
||||
return self / (1 + self.abs())
|
||||
|
||||
Reference in New Issue
Block a user