diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c7ced60a37..be46d2543a 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -5,7 +5,7 @@ from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-from math import ceil, pi, prod, sqrt, log, cos, copysign
+import math
 
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
@@ -128,7 +128,7 @@ class Tensor:
     return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)
 
   @staticmethod
-  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)
+  def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
 
   _seed: int = int(time.time())
   @staticmethod
@@ -137,7 +137,7 @@ class Tensor:
   @staticmethod
   def rand(*shape, **kwargs):
     Tensor._seed += 1
-    return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
+    return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
 
   # ***** creation helper functions *****
 
@@ -153,7 +153,7 @@ class Tensor:
   @staticmethod
   def arange(start, stop=None, step=1, **kwargs):
     if stop is None: stop, start = start, 0
-    return Tensor.full((ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
+    return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
 
   @staticmethod
   def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
@@ -174,7 +174,7 @@
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
     src = Tensor.rand(2, *shape, **kwargs)
-    return src[0].mul(2*pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+    return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
 
   @staticmethod
   def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean
@@ -183,22 +183,22 @@
   def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
 
   @staticmethod
-  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)
+  def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
 
   # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
   @staticmethod
-  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)
+  def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
 
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
   @staticmethod
   def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    std = sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
+    std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
 
   # ***** toposort and backward pass *****
@@ -234,7 +234,7 @@ class Tensor:
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
+    return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
@@ -382,7 +382,7 @@ class Tensor:
     return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)
 
   def chunk(self, num:int, dim:int) -> List[Tensor]:
-    dim, step = dim + self.ndim if dim < 0 else dim, ceil(self.shape[dim]/num)
+    dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
     slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
     return [self[tuple(sl)] for sl in slice_params]
 
@@ -425,10 +425,10 @@ class Tensor:
 
   def mean(self, axis=None, keepdim=False):
     out = self.sum(axis=axis, keepdim=keepdim)
-    return out * (prod(out.shape)/prod(self.shape))
+    return out * (math.prod(out.shape)/math.prod(self.shape))
   def std(self, axis=None, keepdim=False, correction=1):
     square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
-    return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
+    return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
   def _softmax(self, axis):
     m = self - self.max(axis=axis, keepdim=True)
     e = m.exp()
@@ -444,8 +444,8 @@
 
   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1).reshape(self.shape)
-      return prod(self.shape) - idx.max() - 1
+      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+      return math.prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
     idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
@@ -461,7 +461,7 @@ class Tensor:
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
-      e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
+      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
       xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)])
       # slide by dilation
       xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])
@@ -524,14 +524,14 @@ class Tensor:
 
   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
-  def log2(self): return mlops.Log.apply(self)/log(2)
+  def log2(self): return mlops.Log.apply(self)/math.log(2)
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
   def sigmoid(self): return mlops.Sigmoid.apply(self)
   def sin(self): return mlops.Sin.apply(self)
   def sqrt(self): return mlops.Sqrt.apply(self)
   def rsqrt(self): return (1/self).sqrt()
-  def cos(self): return ((pi/2)-self).sin()
+  def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
 
   @staticmethod
@@ -597,12 +597,12 @@ class Tensor:
     if x == 2.0: return self*self
     if x == 1.0: return self
     if x == 0.5: return self.sqrt()
-    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
-    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
+    if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
+    ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
     # correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
-    sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
+    sign = (x * math.pi).cos() if isinstance(x, Tensor) else math.cos(x * math.pi) if not reverse else (self * math.pi).cos()
     # we only need to correct the sign if the base is negative
-    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else copysign(1, x)) - 1) / -2
+    base_sign = ((self.sign() if not reverse else x.sign() if isinstance(x, Tensor) else math.copysign(1, x)) - 1) / -2
     # we need 0 to be positive so we need to correct base_sign when the base is 0
     base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x))))))
     # inject nan if the base is negative and the power is not an integer
@@ -696,7 +696,7 @@ class Tensor:
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
-    return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+    return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
 
   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     loss_mask = Y != ignore_index
@@ -714,7 +714,7 @@
   # ***** Convenience stuff *****
   @property
   def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return prod(self.shape)
+  def numel(self) -> int: return math.prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
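
Notes for reviewers, with small illustrative sketches in plain Python; none of this is part of the patch.

The mechanical change is swapping the bare names from `from math import ...` for `math.`-qualified calls, so nothing here should alter behavior. One property the code relies on: math.prod (available since Python 3.8) returns 1 for an empty iterable, which is what numel(), reshape, and the creation ops need for scalar (0-d) shapes.

import math

# the empty product is 1, so a 0-d shape counts as one element
assert math.prod((3, 4, 5)) == 60
assert math.prod(()) == 1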
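arange is built from a cumulative sum: a vector of ceil((stop-start)/step) copies of step cumsums to step, 2*step, ..., and adding (start - step) shifts it onto start, start+step, ... The same arithmetic in numpy, standing in for the lazy Tensor ops:

import math
import numpy as np

start, stop, step = 2, 11, 3
n = math.ceil((stop - start) / step)              # element count
out = np.full(n, step).cumsum() + (start - step)  # shift k*step onto start+(k-1)*step
print(out)  # [2 5 8], same as np.arange(2, 11, 3)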
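randn is the Box-Muller transform the comment links to: cos(2*pi*u0) * sqrt(-2*log(1 - u1)) maps two uniform samples to one standard normal sample, and using (1 - u1) keeps the log argument in (0, 1] so it never hits log(0). A quick numpy check of the moments:

import math
import numpy as np

rng = np.random.default_rng(0)
u0, u1 = rng.random(100_000), rng.random(100_000)
z = np.cos(2 * math.pi * u0) * np.sqrt(-2 * np.log(1 - u1))
print(round(z.mean(), 2), round(z.std(), 2))  # ~0.0 ~1.0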
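The axis=None branch of argmax masks the positions equal to the max, weights the mask with a reversed arange so earlier positions get larger values, and takes the max of that product; the final prod(shape) - idx.max() - 1 converts the reversed position back to a forward index, so ties resolve to the first occurrence. In numpy terms:

import math
import numpy as np

x = np.array([1, 7, 3, 7, 2])
n = math.prod(x.shape)
idx = (x == x.max()) * np.arange(n - 1, -1, -1)  # reversed positions of the maxima
print(n - idx.max() - 1)  # 1, the first index of the max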
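The pow sign correction leans on the identity cos(pi*k) == (-1)**k for integer k: computing |base|**k via exp/log drops the sign of a negative base, and the cosine factor restores it. A scalar check of the identity:

import math

for base, k in [(-3.0, 2), (-3.0, 3), (-2.0, 5)]:
  # |base|**k loses the sign; cos(pi*k) == (-1)**k puts it back
  assert round(abs(base) ** k * math.cos(math.pi * k), 6) == base ** k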
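scaled_dot_product_attention is the standard softmax(Q @ K^T / sqrt(d)) @ V, with the mask added to the scores before the softmax and dropout applied after it. A minimal numpy sketch of the unmasked path, using the usual max-subtraction for a stable softmax:

import math
import numpy as np

def sdpa(q, k, v):
  scores = q @ k.swapaxes(-2, -1) / math.sqrt(q.shape[-1])
  w = np.exp(scores - scores.max(-1, keepdims=True))  # stable softmax
  return (w / w.sum(-1, keepdims=True)) @ v

q, k, v = (np.ones((2, 4, 8)) for _ in range(3))
print(sdpa(q, k, v).shape)  # (2, 4, 8)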