diff --git a/docs/tensor.md b/docs/tensor.md index f3d5305bd4..7f4c38095f 100644 --- a/docs/tensor.md +++ b/docs/tensor.md @@ -34,6 +34,9 @@ ::: tinygrad.Tensor.full ::: tinygrad.Tensor.arange ::: tinygrad.Tensor.eye +::: tinygrad.Tensor.full_like +::: tinygrad.Tensor.zeros_like +::: tinygrad.Tensor.ones_like ## Creation (random) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ced826116a..3ea4e65896 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -74,7 +74,9 @@ class Tensor: A `Tensor` is a multi-dimensional matrix containing elements of a single data type. ```python exec="true" session="tensor" - from tinygrad import Tensor + from tinygrad import Tensor, dtypes + import numpy as np + np.set_printoptions(precision=5) ``` """ __slots__ = "lazydata", "requires_grad", "grad", "_ctx" @@ -361,8 +363,10 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.full((2, 3), 42) - print(t.numpy()) + print(Tensor.full((2, 3), 42).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.full((2, 3), False).numpy()) ``` """ return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) @@ -376,8 +380,10 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.zeros(2, 3) - print(t.numpy()) + print(Tensor.zeros(2, 3).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.zeros(2, 3, dtype=dtypes.int32).numpy()) ``` """ return Tensor.full(argfix(*shape), 0.0, **kwargs) @@ -391,8 +397,10 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.ones(2, 3) - print(t.numpy()) + print(Tensor.ones(2, 3).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(2, 3, dtype=dtypes.int32).numpy()) ``` """ return Tensor.full(argfix(*shape), 1.0, **kwargs) @@ -408,18 +416,16 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(5) - print(t.numpy()) + print(Tensor.arange(5).numpy()) ``` - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(5, 10) - print(t.numpy()) + print(Tensor.arange(5, 10).numpy()) ``` - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(5, 10, 2) - print(t.numpy()) + print(Tensor.arange(5, 10, 2).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.arange(5.5, 10, 2).numpy()) ``` """ if stop is None: stop, start = start, 0 @@ -436,8 +442,7 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.eye(3) - print(t.numpy()) + print(Tensor.eye(3).numpy()) ``` """ return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim) @@ -451,12 +456,12 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - ot = Tensor.ones(2, 3) - t = Tensor.full_like(ot, 42) - print(t.numpy()) + t = Tensor.ones(2, 3) + print(Tensor.full_like(t, 42).numpy()) ``` """ return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) + def zeros_like(self, **kwargs): """ Creates a tensor with the same shape as `tensor`, filled with zeros. @@ -465,12 +470,12 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - ot = Tensor.ones(2, 3) - t = Tensor.zeros_like(ot) - print(t.numpy()) + t = Tensor.ones(2, 3) + print(Tensor.zeros_like(t).numpy()) ``` """ return self.full_like(0, **kwargs) + def ones_like(self, **kwargs): """ Creates a tensor with the same shape as `tensor`, filled with ones. @@ -479,9 +484,8 @@ class Tensor: Additionally, all other keyword arguments are passed to the constructor of the tensor. ```python exec="true" source="above" session="tensor" result="python" - ot = Tensor.zeros(2, 3) - t = Tensor.ones_like(ot) - print(t.numpy()) + t = Tensor.zeros(2, 3) + print(Tensor.ones_like(t).numpy()) ``` """ return self.full_like(1, **kwargs) @@ -499,8 +503,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) + print(Tensor.randn(2, 3).numpy()) ``` """ # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform @@ -518,8 +521,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy()) + print(Tensor.randint(2, 3, low=5, high=10).numpy()) """ assert dtypes.is_int(dtype := kwargs.pop("dtype", dtypes.int32)), f"Unsupported dtype {dtype} for randint" return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs) @@ -534,8 +536,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.normal(2, 3, mean=10, std=2) - print(t.numpy()) + print(Tensor.normal(2, 3, mean=10, std=2).numpy()) ``` """ return (std * Tensor.randn(*shape, **kwargs)) + mean @@ -550,8 +551,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.uniform(2, 3, low=2, high=10) - print(t.numpy()) + print(Tensor.uniform(2, 3, low=2, high=10).numpy()) ``` """ dtype = kwargs.pop("dtype", dtypes.default_float) @@ -568,8 +568,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.scaled_uniform(2, 3) - print(t.numpy()) + print(Tensor.scaled_uniform(2, 3).numpy()) ``` """ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5) @@ -585,8 +584,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.glorot_uniform(2, 3) - print(t.numpy()) + print(Tensor.glorot_uniform(2, 3).numpy()) ``` """ return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5) @@ -602,8 +600,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.kaiming_uniform(2, 3) - print(t.numpy()) + print(Tensor.kaiming_uniform(2, 3).numpy()) ``` """ bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) @@ -620,8 +617,7 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) - t = Tensor.kaiming_normal(2, 3) - print(t.numpy()) + print(Tensor.kaiming_normal(2, 3).numpy()) ``` """ std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) @@ -653,9 +649,8 @@ class Tensor: Must be used on a scalar tensor. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(6.0, requires_grad=True) - t2 = t.sum() - t2.backward() + t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True) + t.sum().backward() print(t.grad.numpy()) ``` """ @@ -719,7 +714,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.permute(1, 0).numpy()) ``` """ @@ -732,7 +729,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.flip(0).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" @@ -749,7 +748,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(9).reshape(3, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.shrink(((None, (1, 3)))).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" @@ -768,7 +769,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.pad(((None, (1, 2)))).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" @@ -937,7 +940,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor([[1, 2], [3, 4]]) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.gather(1, Tensor([[0, 0], [1, 0]])).numpy()) ``` """ @@ -996,7 +1001,6 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor([1, 2, 3]) - print(t.numpy(), "->") print(t.repeat(4, 2).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" @@ -1119,6 +1123,7 @@ class Tensor: def T(self) -> Tensor: """`.T` is an alias for `.transpose(1, 0)`.""" return self.transpose() + def transpose(self, dim0=1, dim1=0) -> Tensor: """ Returns a tensor that is a transposed version of the original tensor. @@ -1126,7 +1131,9 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" t = Tensor.arange(6).reshape(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.transpose(0, 1).numpy()) ``` """ @@ -1191,9 +1198,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.sum().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.sum(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.sum(axis=1).numpy()) ``` """ @@ -1209,9 +1222,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.max().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.max(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.max(axis=1).numpy()) ``` """ @@ -1226,9 +1245,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.min().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.min(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.min(axis=1).numpy()) ``` """ @@ -1244,15 +1269,22 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.mean().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.mean(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.mean(axis=1).numpy()) ``` """ output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32 numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim) return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if si != so])).cast(output_dtype) + def var(self, axis=None, keepdim=False, correction=1): """ Returns the variance of the tensor along the specified axis or axes. @@ -1263,15 +1295,22 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.var().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.var(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.var(axis=1).numpy()) ``` """ assert all_int(self.shape), "does not support symbolic shape" square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction)) + def std(self, axis=None, keepdim=False, correction=1): """ Returns the standard deviation of the tensor along the specified axis or axes. @@ -1282,9 +1321,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randint(2, 3, low=5, high=10) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.std().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.std(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.std(axis=1).numpy()) ``` """ @@ -1306,8 +1351,12 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.softmax().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.softmax(axis=0).numpy()) ``` """ @@ -1325,8 +1374,12 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.log_softmax().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.log_softmax(axis=0).numpy()) ``` """ @@ -1345,9 +1398,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp().numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp(axis=0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.logsumexp(axis=1).numpy()) ``` """ @@ -1364,9 +1423,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmax().numpy()) # Returns the index of the maximum value in the flattened tensor. + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmax(axis=0).numpy()) # Returns the indices of the maximum values along axis 0. + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmax(axis=1).numpy()) # Returns the indices of the maximum values along axis 1. ``` """ @@ -1377,6 +1442,7 @@ class Tensor: m = self == self.max(axis=axis, keepdim=True) idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32) + def argmin(self, axis=None, keepdim=False): """ Returns the indices of the minimum value of the tensor along the specified axis. @@ -1387,9 +1453,15 @@ class Tensor: ```python exec="true" source="above" session="tensor" result="python" Tensor.manual_seed(42) t = Tensor.randn(2, 3) - print(t.numpy(), "->") + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmin().numpy()) # Returns the index of the minimum value in the flattened tensor. + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmin(axis=0).numpy()) # Returns the indices of the minimum values along axis 0. + ``` + ```python exec="true" source="above" session="tensor" result="python" print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1. ``` """ @@ -1564,12 +1636,7 @@ class Tensor: - Described: https://paperswithcode.com/method/relu ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.relu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) ``` """ return F.Relu.apply(self) @@ -1580,12 +1647,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Sigmoid_function ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.sigmoid().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sigmoid().numpy()) ``` """ return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype))) @@ -1619,12 +1681,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1511.07289v5 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.elu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).elu().numpy()) ``` """ return self.relu() - alpha*(1-self.exp()).relu() @@ -1636,12 +1693,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1704.07483 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.celu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).celu().numpy()) ``` """ return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0) @@ -1652,12 +1704,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1710.05941v1 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.swish().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).swish().numpy()) ``` """ return self * self.sigmoid() @@ -1669,12 +1716,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1606.08415 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.silu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).silu().numpy()) ``` """ return self.swish() # The SiLU function is also known as the swish function. @@ -1686,12 +1728,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1704.04861v1 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.relu6().numpy()) + print(Tensor([-9., -6., -3., 0., 3., 6., 9.]).relu6().numpy()) ``` """ return self.relu() - (self-6).relu() @@ -1703,12 +1740,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1905.02244v5 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.hardswish().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).hardswish().numpy()) ``` """ return self * (self+3).relu6() * (1/6) @@ -1719,12 +1751,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Tanh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.tanh().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).tanh().numpy()) ``` """ return 2.0 * ((2.0 * self).sigmoid()) - 1.0 @@ -1735,12 +1762,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Sinh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.sinh().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sinh().numpy()) ``` """ return (self.exp() - self.neg().exp()) / 2 @@ -1751,12 +1773,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Hyperbolic_functions#Cosh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.cosh().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).cosh().numpy()) ``` """ return (self.exp() + self.neg().exp()) / 2 @@ -1767,12 +1784,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#atanh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = (Tensor.rand(2, 3) + 1) / 2 - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.atanh().numpy()) + print(Tensor([-0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9]).atanh().numpy()) ``` """ return ((1 + self)/(1 - self)).log() / 2 @@ -1783,12 +1795,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#asinh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = (Tensor.rand(2, 3) + 1) / 2 - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.asinh().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).asinh().numpy()) ``` """ return (self + (self.square() + 1).sqrt()).log() @@ -1799,12 +1806,7 @@ class Tensor: - Described: https://en.wikipedia.org/wiki/Inverse_hyperbolic_functions#acosh ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.rand(2, 3) + 1 - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.acosh().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).acosh().numpy()) ``` """ return (self + (self.square() - 1).sqrt()).log() @@ -1815,12 +1817,7 @@ class Tensor: - Described: https://paperswithcode.com/method/hardtanh-activation ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.hardtanh().numpy()) + print(Tensor([-1.5, -1.0, -0.5, 0., 0.5, 1.0, 1.5]).hardtanh().numpy()) ``` """ return self.clip(min_val, max_val) @@ -1832,12 +1829,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1606.08415v5 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.gelu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).gelu().numpy()) ``` """ return 0.5 * self * (1 + (self * 0.7978845608 * (1 + 0.044715 * self * self)).tanh()) @@ -1848,12 +1840,7 @@ class Tensor: - Described: https://paperswithcode.com/method/gelu ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.quick_gelu().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).quick_gelu().numpy()) ``` """ return self * (self * 1.702).sigmoid() @@ -1864,15 +1851,10 @@ class Tensor: - Described: https://paperswithcode.com/method/leaky-relu ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu().numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.leakyrelu().numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.leakyrelu(neg_slope=1).numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).leakyrelu(neg_slope=0.42).numpy()) ``` """ return self.relu() - (-neg_slope*self).relu() @@ -1884,12 +1866,7 @@ class Tensor: - Paper: https://arxiv.org/abs/1908.08681v3 ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.mish().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).mish().numpy()) ``` """ return self * self.softplus().tanh() @@ -1900,12 +1877,7 @@ class Tensor: - Described: https://paperswithcode.com/method/softplus ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.softplus().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softplus().numpy()) ``` """ return (1/beta) * (1 + (self*beta).exp()).log() @@ -1916,12 +1888,7 @@ class Tensor: - Described: https://paperswithcode.com/method/softsign ```python exec="true" source="above" session="tensor" result="python" - Tensor.manual_seed(42) - t = Tensor.randn(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.softsign().numpy()) + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).softsign().numpy()) ``` """ return self / (1 + self.abs())