finish tensor docs (#4722)

wozeparrot
2024-05-24 22:57:43 +00:00
committed by GitHub
parent edf27470c1
commit 5f503226de
2 changed files with 564 additions and 50 deletions

docs/tensor.md

@@ -99,16 +99,16 @@
## Processing
::: tinygrad.Tensor.avg_pool2d
::: tinygrad.Tensor.max_pool2d
::: tinygrad.Tensor.conv2d
::: tinygrad.Tensor.conv_transpose2d
::: tinygrad.Tensor.dot
::: tinygrad.Tensor.matmul
::: tinygrad.Tensor.einsum
::: tinygrad.Tensor.cumsum
::: tinygrad.Tensor.triu
::: tinygrad.Tensor.tril
## Unary Ops (math)
@@ -118,6 +118,11 @@
::: tinygrad.Tensor.log2
::: tinygrad.Tensor.exp
::: tinygrad.Tensor.exp2
::: tinygrad.Tensor.sqrt
::: tinygrad.Tensor.rsqrt
::: tinygrad.Tensor.sin
::: tinygrad.Tensor.cos
::: tinygrad.Tensor.tan
::: tinygrad.Tensor.trunc
::: tinygrad.Tensor.ceil
::: tinygrad.Tensor.floor

tinygrad/tensor.py

@@ -76,6 +76,7 @@ class Tensor:
```python exec="true" session="tensor"
from tinygrad import Tensor, dtypes
import numpy as np
import math
np.set_printoptions(precision=4)
```
"""
@@ -168,6 +169,9 @@ class Tensor:
return self
def replace(self, x:Tensor) -> Tensor:
"""
Replaces the data of this tensor with the data of another tensor. Only the shapes of the two tensors must match.
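A minimal usage sketch (illustrative only; assumes small in-memory tensors of the same shape):
```python
t = Tensor([1, 2, 3, 4])
t.replace(Tensor([5, 6, 7, 8]))  # t now refers to the new data
print(t.numpy())                 # expected: [5 6 7 8]
```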
"""
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert not x.requires_grad and getattr(self, '_ctx', None) is None
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
@@ -192,7 +196,11 @@ class Tensor:
if not self.lazydata.is_realized(): return self.replace(x)
self.lazydata = self.lazydata.assign(x.lazydata)
return self
def detach(self) -> Tensor:
"""
Returns a new tensor with the same data as this tensor, but detached from the autograd graph.
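A minimal sketch showing that the returned tensor no longer tracks gradients:
```python
t = Tensor([1.0, 2.0, 3.0], requires_grad=True)
d = t.detach()
print(d.requires_grad)  # expected: False
```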
"""
return Tensor(self.lazydata, device=self.device, requires_grad=False)
def _data(self) -> memoryview:
if 0 in self.shape: return memoryview(bytearray(0))
@@ -203,6 +211,14 @@ class Tensor:
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
def data(self) -> memoryview:
"""
Returns the data of this tensor as a memoryview.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3, 4])
print(np.frombuffer(t.data(), dtype=np.int32))
```
"""
assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._data().cast(self.dtype.fmt, self.shape)
@@ -245,6 +261,9 @@ class Tensor:
return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape)
def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor:
"""
Moves the tensor to the given device.
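A minimal sketch; the device string below is only an example and depends on which backends are available:
```python
t = Tensor([1, 2, 3])
t2 = t.to("CLANG")  # returns a new tensor on the target device
print(t2.device)
```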
"""
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
if device == self.device: return self
if not isinstance(device, str): return self.shard(device)
@@ -254,18 +273,27 @@ class Tensor:
return ret
def to_(self, device:Optional[Union[str, Tuple[str, ...]]]):
"""
Moves the tensor to the given device in place.
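A minimal sketch, with the same caveat that the device name is only an example:
```python
t = Tensor([1, 2, 3])
t.to_("CLANG")   # moves t itself; nothing is returned
print(t.device)
```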
"""
real = self.to(device)
# TODO: is this assign?
if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
self.lazydata = real.lazydata
def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor:
"""
Shards the tensor across the given devices, optionally along the given `axis`.
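A minimal sketch; the two device names are placeholders for whichever backends are actually available:
```python
t = Tensor.empty(256)
t2 = t.shard(("CLANG", "GPU"), axis=0)  # axis 0 is split across the two devices
print(t2.device)                        # a tuple naming both devices
```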
"""
assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
canonical_devices = tuple(Device.canonicalize(x) for x in devices)
if axis is not None and axis < 0: axis += len(self.shape)
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad)
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None):
"""
Shards the tensor across the given devices in place.
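The in-place variant, under the same placeholder-device assumption:
```python
t = Tensor.empty(256)
t.shard_(("CLANG", "GPU"), axis=0)  # t itself is now sharded
print(t.device)
```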
"""
self.lazydata = self.shard(devices, axis).lazydata
return self
@@ -1462,6 +1490,17 @@ class Tensor:
@staticmethod
def einsum(formula:str, *raw_xs, acc_dtype:Optional[DType]=None) -> Tensor:
"""
Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention.
See: https://pytorch.org/docs/stable/generated/torch.einsum.html
```python exec="true" source="above" session="tensor" result="python"
x = Tensor([[1, 2], [3, 4]])
y = Tensor([[5, 6], [7, 8]])
print(Tensor.einsum("ij,ij->", x, y).numpy())
```
"""
xs:Tuple[Tensor] = argfix(*raw_xs)
formula = formula.replace(" ", "")
inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula))
@@ -1512,25 +1551,51 @@ class Tensor:
return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))])
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
"""
Applies average pooling over a tensor.
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
See: https://paperswithcode.com/method/average-pooling
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.avg_pool2d().numpy())
```
"""
return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1):
"""
Applies max pooling over a tensor.
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
See: https://paperswithcode.com/method/max-pooling
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
print(t.max_pool2d().numpy())
```
"""
return self._pool(
make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0)))
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor:
"""
Applies a convolution over a tensor with a given weight and optional bias.
NOTE: unlike PyTorch, this implementation is not limited to only 2d convolutions and instead works for any number of dimensions.
See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
w = Tensor.ones(1, 1, 2, 2)
print(t.conv2d(w).numpy())
```
"""
(bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:]
assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501
if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501
@@ -1578,7 +1643,42 @@ class Tensor:
return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward()
def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor:
"""
Applies a transposed convolution over a tensor with a given weight and optional bias.
NOTE: unlike PyTorch, this implementation is not limited to only 2d transposed convolutions and instead works for any number of dimensions.
See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.arange(9).reshape(1, 1, 3, 3)
w = Tensor.ones(1, 1, 2, 2)
print(t.conv_transpose2d(w).numpy())
```
"""
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
if any(s>1 for s in stride):
x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:]))
x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride)))
x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)])
x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)]))
padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(
zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW)))))))
return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding)
def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor:
"""
Performs dot product between two tensors.
```python exec="true" source="above" session="tensor" result="python"
a = Tensor([[1, 2], [3, 4]])
b = Tensor([[5, 6], [7, 8]])
print(a.dot(b).numpy())
```
"""
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
@@ -1587,12 +1687,37 @@ class Tensor:
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype))
def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor:
"""
Performs matrix multiplication between two tensors.
You can pass in the `reverse` keyword argument to control the order of the matrix multiplication.
You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation.
```python exec="true" source="above" session="tensor" result="python"
a = Tensor([[1, 2], [3, 4]])
b = Tensor([[5, 6], [7, 8]])
print(a.matmul(b).numpy())
```
"""
return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype)
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0)
return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1)
def cumsum(self, axis:int=0) -> Tensor:
"""
Computes the cumulative sum of the tensor along the specified axis.
You can pass in the `axis` keyword argument to control the axis along which the cumulative sum is computed.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.ones(2, 3)
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.cumsum(1).numpy())
```
"""
# TODO: someday the optimizer will find this on its own
# for now this is a two stage cumsum
SPLIT = 256
@@ -1609,19 +1734,107 @@ class Tensor:
assert all_int((r,c)), "does not support symbolic"
if r == 0: return Tensor.zeros((r, c), **kwargs)
return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor:
"""
Returns the upper triangular part of the tensor; the other elements are set to 0.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2, 3], [4, 5, 6]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.triu(k=1).numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0)
def tril(self, k:int=0) -> Tensor:
"""
Returns the lower triangular part of the tensor; the other elements are set to 0.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2, 3], [4, 5, 6]])
print(t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.tril().numpy())
```
"""
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self)
# ***** unary ops *****
def logical_not(self):
"""
Computes the logical NOT of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([False, True]).logical_not().numpy())
```
"""
return F.Eq.apply(*self._broadcasted(False))
def neg(self):
"""
Negates the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).neg().numpy())
```
"""
return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self):
"""
Returns a contiguous tensor.
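A minimal sketch; the comments describe the typical effect, not a guarantee of when realization happens:
```python
t = Tensor([[1, 2], [3, 4]]).permute(1, 0)  # a lazy, non-contiguous view
print(t.contiguous().numpy())               # same values, now backed by its own buffer
```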
"""
return F.Contiguous.apply(self)
def contiguous_backward(self):
"""
Inserts a contiguous operation in the backward pass.
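A minimal sketch; the forward value is unchanged and only the gradient realization is affected:
```python
t = Tensor([1.0, 2.0, 3.0], requires_grad=True)
y = t.contiguous_backward() * 2  # identity in the forward pass
y.sum().backward()               # the incoming gradient is made contiguous here
print(t.grad.numpy())            # expected: [2. 2. 2.]
```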
"""
return F.ContiguousBackward.apply(self)
def log(self):
"""
Computes the natural logarithm element-wise.
See: https://en.wikipedia.org/wiki/Logarithm
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 4., 8.]).log().numpy())
```
"""
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
def log2(self):
"""
Computes the base-2 logarithm element-wise.
See: https://en.wikipedia.org/wiki/Logarithm
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 4., 8.]).log2().numpy())
```
"""
return self.log()/math.log(2)
def exp(self):
"""
Computes the exponential function element-wise.
See: https://en.wikipedia.org/wiki/Exponential_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., 1., 2., 3.]).exp().numpy())
```
"""
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
def exp2(self):
"""
Computes the base-2 exponential function element-wise.
See: https://en.wikipedia.org/wiki/Exponential_function
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
```
"""
return F.Exp.apply(self*math.log(2))
def relu(self):
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -1644,25 +1857,145 @@ class Tensor:
```
"""
return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
def sqrt(self):
"""
Computes the square root of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).sqrt().numpy())
```
"""
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
def rsqrt(self):
"""
Computes the reciprocal of the square root of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).rsqrt().numpy())
```
"""
return self.reciprocal().sqrt()
def sin(self):
"""
Computes the sine of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
```
"""
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
def cos(self):
"""
Computes the cosine of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).cos().numpy())
```
"""
return ((math.pi/2)-self).sin()
def tan(self):
"""
Computes the tangent of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([0., math.pi/4, math.pi/2, 3*math.pi/4, math.pi]).tan().numpy())
```
"""
return self.sin() / self.cos()
# ***** math functions *****
def trunc(self: Tensor) -> Tensor:
"""
Truncates the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).trunc().numpy())
```
"""
return self.cast(dtypes.int32).cast(self.dtype)
def ceil(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards positive infinity.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).ceil().numpy())
```
"""
return (self > (b := self.trunc())).where(b+1, b)
def floor(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards negative infinity.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).floor().numpy())
```
"""
return (self < (b := self.trunc())).where(b-1, b)
def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise, rounding halfway cases to the nearest even integer.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3.9, -2.1, -1.5, 0.5, 1.5, 2.1, 3.9]).round().numpy())
```
"""
return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor())
def lerp(self, end: Tensor, weight: Union[Tensor, float]) -> Tensor:
"""
Linearly interpolates between `self` and `end` by `weight`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).lerp(Tensor([4., 5., 6.]), 0.5).numpy())
```
"""
return self + (end - self) * weight
def square(self):
"""
Convenience method for squaring the tensor element-wise.
Equivalent to `self*self`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).square().numpy())
```
"""
return self*self
def clip(self, min_, max_):
"""
Clips (limits) the values in the tensor between `min_` and `max_` element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).clip(-1, 1).numpy())
```
"""
return self.maximum(min_).minimum(max_)
def sign(self):
"""
Returns the sign of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
```
"""
return F.Sign.apply(self)
def abs(self):
"""
Computes the absolute value of the tensor element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).abs().numpy())
```
"""
return self * self.sign()
def reciprocal(self):
"""
Computes `1/x` element-wise.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([1., 2., 3.]).reciprocal().numpy())
```
"""
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
# ***** activation functions *****
@@ -2191,16 +2524,67 @@ class Tensor:
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Optional[Tensor]=None):
"""
Applies a linear transformation to `self` using `weight` and `bias`.
See: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[1, 2], [3, 4]])
weight = Tensor([[1, 2], [3, 4]])
bias = Tensor([1, 2])
print(t.linear(weight, bias).numpy())
```
"""
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]):
"""
Applies a sequence of functions to `self`, chaining the output of each function to the input of the next.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([1, 2, 3])
print(t.sequential([lambda x: x * 2, lambda x: x + 1]).numpy())
```
"""
return functools.reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
"""
Applies Layer Normalization over a mini-batch of inputs.
- Described: https://paperswithcode.com/method/layer-normalization
- Paper: https://arxiv.org/abs/1607.06450v1
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.randn(8, 10, 16) * 2 + 8
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.layernorm()
print(t.mean().item(), t.std().item())
```
"""
y = (self - self.mean(axis, keepdim=True))
return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt())
def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor:
"""
Applies Batch Normalization over a mini-batch of inputs.
- Described: https://paperswithcode.com/method/batch-normalization
- Paper: https://arxiv.org/abs/1502.03167
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.randn(8, 4, 16, 16) * 2 + 8
print(t.mean().item(), t.std().item())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.batchnorm(None, None, t.mean(axis=(0,2,3)), t.var(axis=(0,2,3)).add(1e-5).rsqrt())
print(t.mean().item(), t.std().item())
```
"""
axis_ = argfix(axis)
shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape))
x = self - mean.reshape(shape)
@@ -2209,15 +2593,53 @@ class Tensor:
return (ret + bias.reshape(shape)) if bias is not None else ret
def dropout(self, p=0.5) -> Tensor:
"""
Applies dropout to `self`.
NOTE: dropout is only applied when `Tensor.training` is `True`.
- Described: https://paperswithcode.com/method/dropout
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
```python exec="true" source="above" session="tensor" result="python"
Tensor.manual_seed(42)
t = Tensor.randn(2, 2)
with Tensor.train():
print(t.dropout().numpy())
```
"""
if not Tensor.training or p == 0: return self
return self * (Tensor.rand(*self.shape, requires_grad=False, dtype=dtypes.default_float, device=self.device) >= p) * (1/(1.0 - p))
def one_hot(self, num_classes:int) -> Tensor:
"""
Converts `self` to a one-hot tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([0, 1, 3, 3, 4])
print(t.one_hot(5).numpy())
```
"""
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
"""
Computes scaled dot-product attention.
`self` is the query tensor, `key` is the key tensor, and `value` is the value tensor.
NOTE: it also works when `key` and `value` have symbolic shape.
- Described: https://paperswithcode.com/method/scaled
- Paper: https://arxiv.org/abs/1706.03762v7
```python exec="true" source="above" session="tensor" result="python"
q = Tensor.randn(2, 4, 8)
k = Tensor.randn(2, 4, 8)
v = Tensor.randn(2, 4, 8)
print(q.scaled_dot_product_attention(k, v).numpy())
```
"""
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
@@ -2225,14 +2647,48 @@ class Tensor:
return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y:Tensor) -> Tensor:
"""
Computes the binary cross-entropy loss between `self` and `y`.
See: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([0.1, 0.9, 0.2])
y = Tensor([0, 1, 0])
print(t.binary_crossentropy(y).item())
```
"""
return (-y*self.log() - (1-y)*(1-self).log()).mean()
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
"""
Computes the binary cross-entropy loss between `self` and `y` where `self` is logits.
See: https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, -3])
y = Tensor([0, 1, 0])
print(t.binary_crossentropy_logits(y).item())
```
"""
return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean()
def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor:
"""
Computes the sparse categorical cross-entropy loss between `self` and `Y`.
NOTE: `self` is logits and `Y` is the target labels.
See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([[-1, 2, -3], [1, -2, 3]])
Y = Tensor([1, 2])
print(t.sparse_categorical_crossentropy(Y).item())
```
"""
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
# NOTE: self is a logits input
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
@@ -2245,12 +2701,65 @@ class Tensor:
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
def cast(self, dtype:DType) -> Tensor:
"""
Casts `self` to a new dtype.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2.5, 3], dtype=dtypes.float)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
def bitcast(self, dtype:DType) -> Tensor:
"""
Bitcasts `self` to a new dtype of the same itemsize.
`self` must not require a gradient.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.bitcast(dtypes.uint32)
print(t.dtype, t.numpy())
```
"""
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
def float(self) -> Tensor:
"""
Convenience method to cast `self` to a `float32` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.float()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.float32)
def half(self) -> Tensor:
"""
Convenience method to cast `self` to a `float16` Tensor.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor([-1, 2, 3], dtype=dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.half()
print(t.dtype, t.numpy())
```
"""
return self.cast(dtypes.float16)
# ***** convenience stuff *****