Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-02-18 02:21:40 -05:00)
Import whole math module in tensor.py (#1628)
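The diff below swaps the selective import `from math import ceil, pi, prod, sqrt, log, cos, copysign` for a plain `import math` and qualifies every call site (`math.prod`, `math.ceil`, `math.pi`, ...). A minimal sketch of that pattern, assuming nothing beyond the Python standard library (the helper names `numel_of` and `chunk_step` are illustrative only, not part of tinygrad):

import math

# Before: from math import prod, ceil   (names live in the module namespace)
# After:  import math                    (every use is qualified)

def numel_of(shape):
  # math.prod (Python 3.8+) multiplies all elements of an iterable,
  # the same thing the bare prod(shape) calls did before this change.
  return math.prod(shape)

def chunk_step(dim_size, num_chunks):
  # math.ceil rounds up, as in the Tensor.chunk hunk below.
  return math.ceil(dim_size / num_chunks)

assert numel_of((3, 4, 5)) == 60
assert chunk_step(10, 3) == 4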
@@ -5,7 +5,7 @@ from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-from math import ceil, pi, prod, sqrt, log, cos, copysign
+import math

 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
@@ -128,7 +128,7 @@ class Tensor:
     return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

   @staticmethod
-  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
+  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)

   _seed: int = int(time.time())
   @staticmethod
@@ -137,7 +137,7 @@ class Tensor:
   @staticmethod
   def rand(*shape, **kwargs):
     Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)

   # ***** creation helper functions *****

@@ -153,7 +153,7 @@ class Tensor:
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
-    return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
+    return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

   @staticmethod
   def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -174,7 +174,7 @@ class Tensor:
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
     src = Tensor.rand(2, *shape, **kwargs)
-    return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

   @staticmethod
   def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@@ -183,22 +183,22 @@ class Tensor:
   def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low

   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)

   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
-  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

   # ***** toposort and backward pass *****
@@ -234,7 +234,7 @@ class Tensor:
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
+    return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
@@ -382,7 +382,7 @@ class Tensor:
     return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

   def chunk(self, num:int, dim:int) -> List[Tensor]:
-    dim, step = dim + self.ndim if dim < 0 else dim, ceil(self.shape[dim]/num)
+    dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
     slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
     return [self[tuple(sl)] for sl in slice_params]

@@ -425,10 +425,10 @@ class Tensor:

   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
-    return out * (prod(out.shape)/prod(self.shape))
+    return out * (math.prod(out.shape)/math.prod(self.shape))
   def std(self, axis=None, keepdim=False, correction=1):
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
+    return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
     e = m.exp()
@@ -444,8 +444,8 @@ class Tensor:

   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1).reshape(self.shape)
-      return prod(self.shape) - idx.max() - 1
+      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+      return math.prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -461,7 +461,7 @@ class Tensor:
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
-      e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
+      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
       xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
       # slide by dilation
       xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
@@ -524,14 +524,14 @@ class Tensor:

   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
-  def log2(self): return mlops.Log.apply(self)/log(2)
+  def log2(self): return mlops.Log.apply(self)/math.log(2)
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
   def sigmoid(self): return mlops.Sigmoid.apply(self)
   def sin(self): return mlops.Sin.apply(self)
   def sqrt(self): return mlops.Sqrt.apply(self)
   def rsqrt(self): return (1/self).sqrt()
-  def cos(self): return ((pi/2)-self).sin()
+  def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()

   @staticmethod
@@ -597,12 +597,12 @@ class Tensor:
     if x == 2.0: return self*self
     if x == 1.0: return self
     if x == 0.5: return self.sqrt()
-    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
-    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
+    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
+    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
     # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
-    sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
+    sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
     # we only need to correct the sign if the base is negative
-    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2
+    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
     # we need 0 to be positive so we need to correct base_sign when the base is 0
     base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
     # inject nan if the base is negative and the power is not an integer
@@ -696,7 +696,7 @@ class Tensor:
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
-    return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+    return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value

   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     loss_mask = Y != ignore_index
@@ -714,7 +714,7 @@ class Tensor:
   # ***** Convenience stuff *****
   @property
   def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return prod(self.shape)
+  def numel(self) -> int: return math.prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
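As a worked reference for the Box-Muller hunk in randn above, the same transform can be written with scalar math-module calls; this is an illustrative sketch, not tinygrad code, and the function name box_muller_scalar is made up for the example:

import math, random

def box_muller_scalar():
  # Two independent uniforms in [0, 1) ...
  u0, u1 = random.random(), random.random()
  # ... become one standard-normal sample: cos(2*pi*u0) * sqrt(-2*ln(1 - u1)),
  # the per-element value of src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).
  return math.cos(2 * math.pi * u0) * math.sqrt(-2 * math.log(1 - u1))

samples = [box_muller_scalar() for _ in range(10000)]
print(sum(samples) / len(samples))  # sample mean should be near 0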