diff --git a/docs/tensor.md b/docs/tensor.md index f32813d2c9..ffa5f5f55b 100644 --- a/docs/tensor.md +++ b/docs/tensor.md @@ -99,16 +99,16 @@ ## Processing +::: tinygrad.Tensor.avg_pool2d +::: tinygrad.Tensor.max_pool2d ::: tinygrad.Tensor.conv2d +::: tinygrad.Tensor.conv_transpose2d ::: tinygrad.Tensor.dot ::: tinygrad.Tensor.matmul ::: tinygrad.Tensor.einsum ::: tinygrad.Tensor.cumsum ::: tinygrad.Tensor.triu ::: tinygrad.Tensor.tril -::: tinygrad.Tensor.avg_pool2d -::: tinygrad.Tensor.max_pool2d -::: tinygrad.Tensor.conv_transpose2d ## Unary Ops (math) @@ -118,6 +118,11 @@ ::: tinygrad.Tensor.log2 ::: tinygrad.Tensor.exp ::: tinygrad.Tensor.exp2 +::: tinygrad.Tensor.sqrt +::: tinygrad.Tensor.rsqrt +::: tinygrad.Tensor.sin +::: tinygrad.Tensor.cos +::: tinygrad.Tensor.tan ::: tinygrad.Tensor.trunc ::: tinygrad.Tensor.ceil ::: tinygrad.Tensor.floor diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 31ac402395..92e6d4cccd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -76,6 +76,7 @@ class Tensor: ```python exec="true" session="tensor" from tinygrad import Tensor, dtypes import numpy as np + import math np.set_printoptions(precision=4) ``` """ @@ -168,6 +169,9 @@ class Tensor: return self def replace(self, x:Tensor) -> Tensor: + """ + Replace the data of this tensor with the data of another tensor. Only the shape of the tensors must match. + """ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) assert not x.requires_grad and getattr(self, '_ctx', None) is None assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}" @@ -192,7 +196,11 @@ class Tensor: if not self.lazydata.is_realized(): return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) return self - def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False) + def detach(self) -> Tensor: + """ + Returns a new tensor with the same data as this tensor, but detached from the autograd graph. + """ + return Tensor(self.lazydata, device=self.device, requires_grad=False) def _data(self) -> memoryview: if 0 in self.shape: return memoryview(bytearray(0)) @@ -203,6 +211,14 @@ class Tensor: return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False) def data(self) -> memoryview: + """ + Returns the data of this tensor as a memoryview. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3, 4]) + print(np.frombuffer(t.data(), dtype=np.int32)) + ``` + """ assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}" assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" return self._data().cast(self.dtype.fmt, self.shape) @@ -245,6 +261,9 @@ class Tensor: return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape) def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor: + """ + Moves the tensor to the given device. + """ device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) if device == self.device: return self if not isinstance(device, str): return self.shard(device) @@ -254,18 +273,27 @@ class Tensor: return ret def to_(self, device:Optional[Union[str, Tuple[str, ...]]]): + """ + Moves the tensor to the given device in place. + """ real = self.to(device) # TODO: is this assign? 
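+    # the gradient's buffer is swapped as well, so self.grad also ends up on the new device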
if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata self.lazydata = real.lazydata def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor: + """ + Shards the tensor across the given devices. + """ assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer" canonical_devices = tuple(Device.canonicalize(x) for x in devices) if axis is not None and axis < 0: axis += len(self.shape) return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad) def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None): + """ + Shards the tensor across the given devices in place. + """ self.lazydata = self.shard(devices, axis).lazydata return self @@ -1462,6 +1490,17 @@ class Tensor: @staticmethod def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor: + """ + Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention. + + See: https://pytorch.org/docs/stable/generated/torch.einsum.html + + ```python exec="true" source="above" session="tensor" result="python" + x = Tensor([[1, 2], [3, 4]]) + y = Tensor([[5, 6], [7, 8]]) + print(Tensor.einsum("ij,ij->", x, y).numpy()) + ``` + """ xs:Tuple[Tensor] = argfix(*raw_xs) formula = formula.replace(" ", "") inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula)) @@ -1512,25 +1551,51 @@ class Tensor: return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))]) # NOTE: these work for more than 2D - def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool( - make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) - def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool( - make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) + def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): + """ + Applies average pooling over a tensor. - def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor: - HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1)) - x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing) - stride = make_pair(stride, len(HW)) - if any(s>1 for s in stride): - x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:])) - x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride))) - x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)]) - x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) - padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list( - zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) - return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) + NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. 
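+
+    For example, a 3-element kernel pools over the last three dimensions of a 5-D input (a small sketch of that claim; shapes and values are arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(8).reshape(1, 1, 2, 2, 2)
+    print(t.avg_pool2d(kernel_size=(2, 2, 2)).numpy())
+    ```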
+ + See: https://paperswithcode.com/method/average-pooling + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(1, 1, 3, 3) + print(t.avg_pool2d().numpy()) + ``` + """ + return self._pool( + make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) + def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): + """ + Applies max pooling over a tensor. + + NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions. + + See: https://paperswithcode.com/method/max-pooling + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(1, 1, 3, 3) + print(t.max_pool2d().numpy()) + ``` + """ + return self._pool( + make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor: + """ + Applies a convolution over a tensor with a given weight and optional bias. + + NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions. + + See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(1, 1, 3, 3) + w = Tensor.ones(1, 1, 2, 2) + print(t.conv2d(w).numpy()) + ``` + """ (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501 if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501 @@ -1578,7 +1643,42 @@ class Tensor: return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() + def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor: + """ + Applies a transposed convolution over a tensor with a given weight and optional bias. + + NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions. 
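+
+    For example, a `stride` greater than 1 upsamples the input spatially (an illustrative sketch; the stride value is arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(1, 1, 3, 3)
+    w = Tensor.ones(1, 1, 2, 2)
+    print(t.conv_transpose2d(w, stride=2).numpy())
+    ```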
+ + See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(1, 1, 3, 3) + w = Tensor.ones(1, 1, 2, 2) + print(t.conv_transpose2d(w).numpy()) + ``` + """ + HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1)) + x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing) + stride = make_pair(stride, len(HW)) + if any(s>1 for s in stride): + x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:])) + x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride))) + x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)]) + x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) + padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list( + zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) + return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) + def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor: + """ + Performs dot product between two tensors. + + ```python exec="true" source="above" session="tensor" result="python" + a = Tensor([[1, 2], [3, 4]]) + b = Tensor([[5, 6], [7, 8]]) + print(a.dot(b).numpy()) + ``` + """ n1, n2 = len(self.shape), len(w.shape) assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})" @@ -1587,12 +1687,37 @@ class Tensor: return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype)) def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor: + """ + Performs matrix multiplication between two tensors. + + You can pass in the `reverse` keyword argument to control the order of the matrix multiplication. + You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation. + + ```python exec="true" source="above" session="tensor" result="python" + a = Tensor([[1, 2], [3, 4]]) + b = Tensor([[5, 6], [7, 8]]) + print(a.matmul(b).numpy()) + ``` + """ return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype) def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0) return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1) def cumsum(self, axis:int=0) -> Tensor: + """ + Computes the cumulative sum of the tensor along the specified axis. + + You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed. 
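+
+    For instance, with `axis=0` (the default) the sum accumulates down the rows (a small sketch, values chosen arbitrarily):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1, 2, 3], [4, 5, 6]])
+    print(t.cumsum(0).numpy())
+    ```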
+ + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.ones(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.cumsum(1).numpy()) + ``` + """ # TODO: someday the optimizer will find this on it's own # for now this is a two stage cumsum SPLIT = 256 @@ -1609,19 +1734,107 @@ class Tensor: assert all_int((r,c)), "does not support symbolic" if r == 0: return Tensor.zeros((r, c), **kwargs) return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) - def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0) - def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self) + def triu(self, k:int=0) -> Tensor: + """ + Returns the upper triangular part of the tensor, the other elements are set to 0. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[1, 2, 3], [4, 5, 6]]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.triu(k=1).numpy()) + ``` + """ + return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0) + def tril(self, k:int=0) -> Tensor: + """ + Returns the lower triangular part of the tensor, the other elements are set to 0. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[1, 2, 3], [4, 5, 6]]) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.tril().numpy()) + ``` + """ + return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self) # ***** unary ops ***** - def logical_not(self): return F.Eq.apply(*self._broadcasted(False)) - def neg(self): return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not() - def contiguous(self): return F.Contiguous.apply(self) - def contiguous_backward(self): return F.ContiguousBackward.apply(self) - def log(self): return F.Log.apply(self.cast(least_upper_float(self.dtype))) - def log2(self): return self.log()/math.log(2) - def exp(self): return F.Exp.apply(self.cast(least_upper_float(self.dtype))) - def exp2(self): return F.Exp.apply(self*math.log(2)) + def logical_not(self): + """ + Computes the logical NOT of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([False, True]).logical_not().numpy()) + ``` + """ + return F.Eq.apply(*self._broadcasted(False)) + def neg(self): + """ + Negates the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy()) + ``` + """ + return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not() + def contiguous(self): + """ + Returns a contiguous tensor. + """ + return F.Contiguous.apply(self) + def contiguous_backward(self): + """ + Inserts a contiguous operation in the backward pass. + """ + return F.ContiguousBackward.apply(self) + def log(self): + """ + Computes the natural logarithm element-wise. 
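+
+    Since `exp` is its inverse, applying `log` to `exp` of a tensor recovers the input up to floating-point error (a quick numerical sketch, values arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1., 2., 3.])
+    print(t.exp().log().numpy())
+    ```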
+ + See: https://en.wikipedia.org/wiki/Logarithm + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 4., 8.]).log().numpy()) + ``` + """ + return F.Log.apply(self.cast(least_upper_float(self.dtype))) + def log2(self): + """ + Computes the base-2 logarithm element-wise. + + See: https://en.wikipedia.org/wiki/Logarithm + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 4., 8.]).log2().numpy()) + ``` + """ + return self.log()/math.log(2) + def exp(self): + """ + Computes the exponential function element-wise. + + See: https://en.wikipedia.org/wiki/Exponential_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., 1., 2., 3.]).exp().numpy()) + ``` + """ + return F.Exp.apply(self.cast(least_upper_float(self.dtype))) + def exp2(self): + """ + Computes the base-2 exponential function element-wise. + + See: https://en.wikipedia.org/wiki/Exponential_function + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., 1., 2., 3.]).exp2().numpy()) + ``` + """ + return F.Exp.apply(self*math.log(2)) def relu(self): """ Applies the Rectified Linear Unit (ReLU) function element-wise. @@ -1644,25 +1857,145 @@ class Tensor: ``` """ return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype))) - def sin(self): return F.Sin.apply(self.cast(least_upper_float(self.dtype))) - def sqrt(self): return F.Sqrt.apply(self.cast(least_upper_float(self.dtype))) - def rsqrt(self): return self.reciprocal().sqrt() - def cos(self): return ((math.pi/2)-self).sin() - def tan(self): return self.sin() / self.cos() + def sqrt(self): + """ + Computes the square root of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3.]).sqrt().numpy()) + ``` + """ + return F.Sqrt.apply(self.cast(least_upper_float(self.dtype))) + def rsqrt(self): + """ + Computes the reciprocal of the square root of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3.]).rsqrt().numpy()) + ``` + """ + return self.reciprocal().sqrt() + def sin(self): + """ + Computes the sine of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) + ``` + """ + return F.Sin.apply(self.cast(least_upper_float(self.dtype))) + def cos(self): + """ + Computes the cosine of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy()) + ``` + """ + return ((math.pi/2)-self).sin() + def tan(self): + """ + Computes the tangent of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy()) + ``` + """ + return self.sin() / self.cos() # ***** math functions ***** - def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype) - def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b) - def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b) + def trunc(self: Tensor) -> Tensor: + """ + Truncates the tensor element-wise. 
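+
+    Truncation rounds toward zero, so it differs from `floor` for negative inputs (a small comparative sketch, values arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([-1.5, 1.5])
+    print(t.trunc().numpy(), t.floor().numpy())
+    ```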
+ + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy()) + ``` + """ + return self.cast(dtypes.int32).cast(self.dtype) + def ceil(self: Tensor) -> Tensor: + """ + Rounds the tensor element-wise towards positive infinity. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy()) + ``` + """ + return (self > (b := self.trunc())).where(b+1, b) + def floor(self: Tensor) -> Tensor: + """ + Rounds the tensor element-wise towards negative infinity. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy()) + ``` + """ + return (self < (b := self.trunc())).where(b-1, b) def round(self: Tensor) -> Tensor: + """ + Rounds the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy()) + ``` + """ return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) - def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: return self + (end - self) * weight - def square(self): return self*self - def clip(self, min_, max_): return self.maximum(min_).minimum(max_) - def sign(self): return F.Sign.apply(self) - def abs(self): return self * self.sign() - def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype))) + def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor: + """ + Linearly interpolates between `self` and `end` by `weight`. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy()) + ``` + """ + return self + (end - self) * weight + def square(self): + """ + Convenience method for squaring the tensor element-wise. + Equivalent to `self*self`. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy()) + ``` + """ + return self*self + def clip(self, min_, max_): + """ + Clips (limits) the values in the tensor between `min_` and `max_` element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy()) + ``` + """ + return self.maximum(min_).minimum(max_) + def sign(self): + """ + Returns the sign of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) + ``` + """ + return F.Sign.apply(self) + def abs(self): + """ + Computes the absolute value of the tensor element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy()) + ``` + """ + return self * self.sign() + def reciprocal(self): + """ + Compute `1/x` element-wise. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([1., 2., 3.]).reciprocal().numpy()) + ``` + """ + return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype))) # ***** activation functions ***** @@ -2191,16 +2524,67 @@ class Tensor: # ***** functional nn ops ***** def linear(self, weight:Tensor, bias:Optional[Tensor]=None): + """ + Applies a linear transformation to `self` using `weight` and `bias`. 
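+
+    With a 2-D `weight` this computes `self @ weight + bias`; a 1-D `weight` is applied as an element-wise multiply instead, as in this small sketch (values are arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[1., 2.], [3., 4.]])
+    print(t.linear(Tensor([10., 100.]), Tensor([1., 1.])).numpy())
+    ```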
+ + See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[1, 2], [3, 4]]) + weight = Tensor([[1, 2], [3, 4]]) + bias = Tensor([1, 2]) + print(t.linear(weight, bias).numpy()) + ``` + """ x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight) return x.add(bias) if bias is not None else x - def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self) + def sequential(self, ll:List[Callable[[Tensor], Tensor]]): + """ + Applies a sequence of functions to `self` chaining the output of each function to the input of the next. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3]) + print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy()) + ``` + """ + return functools.reduce(lambda x,f: f(x), ll, self) def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor: + """ + Applies Layer Normalization over a mini-batch of inputs. + + - Described: https://paperswithcode.com/method/layer-normalization + - Paper: https://arxiv.org/abs/1607.06450v1 + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.randn(8, 10, 16) * 2 + 8 + print(t.mean().item(), t.std().item()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.layernorm() + print(t.mean().item(), t.std().item()) + ``` + """ y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: + """ + Applies Batch Normalization over a mini-batch of inputs. + + - Described: https://paperswithcode.com/method/batch-normalization + - Paper: https://arxiv.org/abs/1502.03167 + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.randn(8, 4, 16, 16) * 2 + 8 + print(t.mean().item(), t.std().item()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt()) + print(t.mean().item(), t.std().item()) + ``` + """ axis_ = argfix(axis) shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape)) x = self - mean.reshape(shape) @@ -2209,15 +2593,53 @@ class Tensor: return (ret + bias.reshape(shape)) if bias is not None else ret def dropout(self, p=0.5) -> Tensor: + """ + Applies dropout to `self`. + + NOTE: dropout is only applied when `Tensor.training` is `True`. + + - Described: https://paperswithcode.com/method/dropout + - Paper: https://jmlr.org/papers/v15/srivastava14a.html + + ```python exec="true" source="above" session="tensor" result="python" + Tensor.manual_seed(42) + t = Tensor.randn(2, 2) + with Tensor.train(): + print(t.dropout().numpy()) + ``` + """ if not Tensor.training or p == 0: return self return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p)) def one_hot(self, num_classes:int) -> Tensor: + """ + Converts `self` to a one-hot tensor. 
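+
+    The result gains a trailing dimension of length `num_classes`, and multi-dimensional index tensors work too (a small sketch, indices chosen arbitrarily):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[0, 2], [1, 0]])
+    print(t.one_hot(3).numpy())
+    ```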
+ + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([0, 1, 3, 3, 4]) + print(t.one_hot(5).numpy()) + ``` + """ return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: - # NOTE: it works if key, value have symbolic shape + """ + Computes scaled dot-product attention. + `self` is the query tensor, `key` is the key tensor, and `value` is the value tensor. + + NOTE: it also works when `key` and `value` have symbolic shape. + + - Described: https://paperswithcode.com/method/scaled + - Paper: https://arxiv.org/abs/1706.03762v7 + + ```python exec="true" source="above" session="tensor" result="python" + q = Tensor.randn(2, 4, 8) + k = Tensor.randn(2, 4, 8) + v = Tensor.randn(2, 4, 8) + print(q.scaled_dot_product_attention(k, v).numpy()) + ``` + """ assert all_int(self.shape), f"does not support symbolic shape {self.shape}" if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool) if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0) @@ -2225,14 +2647,48 @@ class Tensor: return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value def binary_crossentropy(self, y:Tensor) -> Tensor: + """ + Computes the binary cross-entropy loss between `self` and `y`. + + See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([0.1, 0.9, 0.2]) + y = Tensor([0, 1, 0]) + print(t.binary_crossentropy(y).item()) + ``` + """ return (-y*self.log() - (1-y)*(1-self).log()).mean() def binary_crossentropy_logits(self, y:Tensor) -> Tensor: + """ + Computes the binary cross-entropy loss between `self` and `y` where `self` is logits. + + See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 2, -3]) + y = Tensor([0, 1, 0]) + print(t.binary_crossentropy_logits(y).item()) + ``` + """ return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean() def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor: + """ + Computes the sparse categorical cross-entropy loss between `self` and `Y`. + + NOTE: `self` is logits and `Y` is the target labels. 
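+
+    The optional `ignore_index` and `label_smoothing` arguments from the signature can be used as well; a brief sketch (the smoothing value is arbitrary):
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([[-1, 2, -3], [1, -2, 3]])
+    Y = Tensor([1, 2])
+    print(t.sparse_categorical_crossentropy(Y, label_smoothing=0.1).item())
+    ```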
+ + See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([[-1, 2, -3], [1, -2, 3]]) + Y = Tensor([1, 2]) + print(t.sparse_categorical_crossentropy(Y).item()) + ``` + """ assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" - # NOTE: self is a logits input log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) @@ -2245,12 +2701,65 @@ class Tensor: # hack for devices that don't support bfloat16 assert self.dtype == dtypes.bfloat16 return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype) - def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype) + def cast(self, dtype:DType) -> Tensor: + """ + Casts `self` to a new dtype. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 2.5, 3], dtype=dtypes.float) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.cast(dtypes.int32) + print(t.dtype, t.numpy()) + ``` + """ + return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype) def bitcast(self, dtype:DType) -> Tensor: + """ + Bitcasts `self` to a new dtype of the same itemsize. + + `self` must not require a gradient. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 2, 3], dtype=dtypes.int32) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.bitcast(dtypes.uint32) + print(t.dtype, t.numpy()) + ``` + """ if self.requires_grad: raise RuntimeError("can't backprop through bitcast") return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self - def float(self) -> Tensor: return self.cast(dtypes.float32) - def half(self) -> Tensor: return self.cast(dtypes.float16) + def float(self) -> Tensor: + """ + Convenience method to cast `self` to a `float32` Tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 2, 3], dtype=dtypes.int32) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.float() + print(t.dtype, t.numpy()) + ``` + """ + return self.cast(dtypes.float32) + def half(self) -> Tensor: + """ + Convenience method to cast `self` to a `float16` Tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 2, 3], dtype=dtypes.int32) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.half() + print(t.dtype, t.numpy()) + ``` + """ + return self.cast(dtypes.float16) # ***** convenience stuff *****