Import whole math module in tensor.py (#1628)

c143 authored 2023-08-22 23:07:46 +02:00, committed by GitHub
parent 6fcfa50b35
commit c9c40bb16f


@@ -5,7 +5,7 @@ from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-from math import ceil, pi, prod, sqrt, log, cos, copysign
+import math
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from tinygrad.lazy import Device, LazyBuffer
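
Note, not part of the diff: switching from "from math import ..." to "import math" only changes how the names are referenced; the underlying functions are identical. A minimal sanity check:

import math
from math import prod, ceil
assert math.prod((3, 4, 5)) == prod((3, 4, 5)) == 60   # same function, different spelling
assert math.ceil(7 / 2) == ceil(7 / 2) == 4
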
@@ -128,7 +128,7 @@ class Tensor:
return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
@staticmethod
-def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
+def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
_seed: int = int(time.time())
@staticmethod
@@ -137,7 +137,7 @@ class Tensor:
@staticmethod
def rand(*shape, **kwargs):
Tensor._seed += 1
-return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
# ***** creation helper functions *****
@@ -153,7 +153,7 @@ class Tensor:
@staticmethod
def arange(start, stop=None, step=1, **kwargs):
if stop is None: stop, start = start, 0
-return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
+return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
@staticmethod
def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
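
Side note on the arange construction above (illustrative, not from the diff): it fills a tensor of length ceil((stop-start)/step) with the step value, takes a cumulative sum, and shifts by start - step. A plain-Python sketch of the same idea, using a hypothetical helper name:

import math
from itertools import accumulate

def arange_sketch(start, stop=None, step=1):
    if stop is None: stop, start = start, 0
    n = math.ceil((stop - start) / step)                          # number of elements
    return [v + (start - step) for v in accumulate([step] * n)]   # cumsum of steps, shifted back to start

assert arange_sketch(5) == [0, 1, 2, 3, 4]
assert arange_sketch(2, 10, 3) == [2, 5, 8]
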
@@ -174,7 +174,7 @@ class Tensor:
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
-return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
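
The randn above uses the Box-Muller transform linked in the comment: two independent uniform samples are mapped to one standard-normal sample via cos(2*pi*u0) * sqrt(-2*ln(1-u1)). A rough NumPy check of that mapping (illustrative only):

import numpy as np

u = np.random.rand(2, 100_000)                                   # two independent U(0, 1) sources
z = np.cos(2 * np.pi * u[0]) * np.sqrt(-2 * np.log(1 - u[1]))    # Box-Muller, cosine branch
assert abs(z.mean()) < 0.05 and abs(z.std() - 1) < 0.05          # roughly standard normal
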
@@ -183,22 +183,22 @@ class Tensor:
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
@staticmethod
-def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
+def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
-def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
+def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
-std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
# ***** toposort and backward pass *****
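
On the kaiming_uniform bound above: a uniform distribution on [-b, b] has standard deviation b/sqrt(3), so bound = sqrt(3) * gain / sqrt(fan_in) yields samples with std gain/sqrt(fan_in), where gain = sqrt(2/(1+a^2)) and fan_in = prod(shape[1:]). A quick empirical check (illustrative, fan_in chosen arbitrarily):

import math
import numpy as np

a, fan_in = 0.01, 512
gain = math.sqrt(2.0 / (1 + a ** 2))
bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
samples = np.random.uniform(-bound, bound, size=1_000_000)
assert abs(samples.std() - gain / math.sqrt(fan_in)) < 1e-3      # U(-b, b) has std b/sqrt(3)
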
@@ -234,7 +234,7 @@ class Tensor:
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
+return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
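
The -1 inference in reshape above works because prod(new_shape) is negative exactly when a single -1 is present, so -prod(old_shape) // prod(new_shape) comes out as the missing positive dimension. Worked example (illustrative):

import math

old_shape, new_shape = (2, 3, 4), (6, -1)
inferred = -math.prod(old_shape) // math.prod(new_shape)         # -24 // -6 == 4
assert tuple(inferred if s == -1 else s for s in new_shape) == (6, 4)
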
@@ -382,7 +382,7 @@ class Tensor:
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
def chunk(self, num:int, dim:int) -> List[Tensor]:
-dim, step = dim + self.ndim if dim < 0 else dim, ceil(self.shape[dim]/num)
+dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
return [self[tuple(sl)] for sl in slice_params]
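
For intuition on the chunk step above: the split width is ceil(length / num), so every chunk has that width except possibly the last one. A tiny sketch of the resulting slice bounds (hypothetical helper, not from the diff):

import math

def chunk_bounds_sketch(length, num):
    step = math.ceil(length / num)                               # width used for every slice
    return [(k, min(k + step, length)) for k in range(0, length, step)]

assert chunk_bounds_sketch(10, 3) == [(0, 4), (4, 8), (8, 10)]   # last chunk is shorter
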
@@ -425,10 +425,10 @@ class Tensor:
def mean(self, axis=None, keepdim=False):
out = self.sum(axis=axis, keepdim=keepdim)
-return out * (prod(out.shape)/prod(self.shape))
+return out * (math.prod(out.shape)/math.prod(self.shape))
def std(self, axis=None, keepdim=False, correction=1):
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
+return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
def _softmax(self, axis):
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
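
The mean above rescales the summed tensor by the ratio of element counts after and before reduction, which equals dividing by the number of reduced elements; std applies the analogous count to the sum of squared deviations. A quick NumPy check of the mean identity (illustrative):

import math
import numpy as np

x = np.arange(24.0).reshape(2, 3, 4)
s = x.sum(axis=1)                                                # reduce over an axis of length 3
scale = math.prod(s.shape) / math.prod(x.shape)                  # 8 / 24 == 1/3
assert np.allclose(s * scale, x.mean(axis=1))
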
@@ -444,8 +444,8 @@ class Tensor:
def argmax(self, axis=None, keepdim=False):
if axis is None:
-idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1).reshape(self.shape)
-return prod(self.shape) - idx.max() - 1
+idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+return math.prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
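
The argmax trick above multiplies the equality mask by a reversed index ramp, so among tied maxima the earliest position gets the largest ramp value; subtracting the max of that product from the element count recovers the first index. Small plain-Python check (illustrative):

vals = [3, 7, 1, 7]                                              # the max 7 appears at indices 1 and 3
n = len(vals)
ramp = list(range(n - 1, -1, -1))                                # reversed arange: [3, 2, 1, 0]
masked = [r if v == max(vals) else 0 for v, r in zip(vals, ramp)]
assert n - max(masked) - 1 == 1                                  # first occurrence of the max
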
@@ -461,7 +461,7 @@ class Tensor:
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
-e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
+e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
# slide by dilation
xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
@@ -524,14 +524,14 @@ class Tensor:
def contiguous(self): return mlops.Contiguous.apply(self)
def log(self): return mlops.Log.apply(self)
-def log2(self): return mlops.Log.apply(self)/log(2)
+def log2(self): return mlops.Log.apply(self)/math.log(2)
def exp(self): return mlops.Exp.apply(self)
def relu(self): return mlops.Relu.apply(self)
def sigmoid(self): return mlops.Sigmoid.apply(self)
def sin(self): return mlops.Sin.apply(self)
def sqrt(self): return mlops.Sqrt.apply(self)
def rsqrt(self): return (1/self).sqrt()
-def cos(self): return ((pi/2)-self).sin()
+def cos(self): return ((math.pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@staticmethod
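
Two identities are used above: log2(x) is the natural log divided by ln 2, and cos(x) is expressed through the existing sin primitive as sin(pi/2 - x). Quick numeric check (illustrative):

import math

x = 5.0
assert math.isclose(math.log(x) / math.log(2), math.log2(x))
assert math.isclose(math.sin(math.pi / 2 - x), math.cos(x))
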
@@ -597,12 +597,12 @@ class Tensor:
if x == 2.0: return self*self
if x == 1.0: return self
if x == 0.5: return self.sqrt()
-if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
-ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
+if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
+ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
-sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
+sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
# we only need to correct the sign if the base is negative
-base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2
+base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
# we need 0 to be positive so we need to correct base_sign when the base is 0
base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
# inject nan if the base is negative and the power is not an integer
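
The sign correction above relies on cos(x * pi) being +1 for even integer exponents and -1 for odd ones, which matches the sign of a negative base raised to x. Quick check (illustrative):

import math

for x in range(5):
    assert math.isclose(math.cos(x * math.pi), (-1) ** x)
    assert math.isclose(math.copysign(1, (-2.0) ** x), math.cos(x * math.pi))
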
@@ -696,7 +696,7 @@ class Tensor:
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
-return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
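
The attention line above is the standard softmax(Q K^T / sqrt(d)) V with an optional additive mask (boolean masks are first converted to -inf at blocked positions). A compact NumPy sketch of the same computation, ignoring dropout (illustrative):

import math
import numpy as np

def sdpa_sketch(q, k, v, mask=None):
    scores = q @ k.swapaxes(-2, -1) / math.sqrt(q.shape[-1])     # Q K^T / sqrt(d)
    if mask is not None: scores = scores + mask                  # -inf entries block positions
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))      # numerically stable softmax
    return (w / w.sum(axis=-1, keepdims=True)) @ v

q, k, v = np.random.rand(2, 4, 8), np.random.rand(2, 5, 8), np.random.rand(2, 5, 8)
assert sdpa_sketch(q, k, v).shape == (2, 4, 8)
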
@@ -714,7 +714,7 @@ class Tensor:
# ***** Convenience stuff *****
@property
def ndim(self) -> int: return len(self.shape)
-def numel(self) -> int: return prod(self.shape)
+def numel(self) -> int: return math.prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)