From eda43767de9e6436d6a4b6056e25c84beea44410 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 5 Jan 2024 13:56:26 -0500 Subject: [PATCH] use Scalar = Union[float, int, bool] in tensor.py (#3021) unify the type spec for Tensor creation functions and broadcasted elementwise ops that take python scalar --- tinygrad/tensor.py | 41 +++++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1c716394c0..5563b47bab 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,7 +1,7 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations import time, math, itertools -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast +from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast, get_args from collections import defaultdict from functools import partialmethod, reduce import numpy as np @@ -40,6 +40,8 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src) return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None) +Scalar = Union[float, int, bool] + class Tensor: __slots__ = "lazydata", "requires_grad", "grad", "_ctx" __deletable__ = ('_ctx',) @@ -50,7 +52,7 @@ class Tensor: def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev no_grad: ClassVar[bool] = False - def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer], + def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer], device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) @@ -64,7 +66,7 @@ class Tensor: # internal variables used for autograd graph construction self._ctx: Optional[Function] = None if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, (bool, int, float)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) + elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device) elif isinstance(data, list): @@ -131,7 +133,7 @@ class Tensor: def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False) # TODO: these are good places to start removing numpy - def item(self) -> Union[float, int, bool]: + def item(self) -> Scalar: assert self.numel() == 1, "must have one element for item" return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item() def data(self) -> memoryview: return self.numpy().data @@ -184,7 +186,7 @@ class Tensor: # ***** creation helper functions ***** @staticmethod - def full(shape:Tuple[sint, ...], fill_value: Union[bool, int, float], **kwargs): + def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs): return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod @@ -203,7 +205,7 @@ class Tensor: def eye(dim:int, **kwargs): return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) - def full_like(self, fill_value: Union[bool, int, float], **kwargs): + def full_like(self, fill_value:Scalar, **kwargs): return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) def zeros_like(self, **kwargs): return self.full_like(0, **kwargs) def ones_like(self, **kwargs): return self.full_like(1, **kwargs) @@ -680,6 +682,8 @@ class Tensor: w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype)) + def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) + def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) def cumsum(self, axis:int=0) -> Tensor: @@ -753,9 +757,9 @@ class Tensor: def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log() def softsign(self): return self / (1 + self.abs()) - # ***** broadcasted binary mlops ***** + # ***** broadcasted elementwise mlops ***** - def _broadcasted(self, y:Union[Tensor, float, int, bool], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]: + def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]: x: Tensor = self if not isinstance(y, Tensor): # make y a Tensor @@ -777,25 +781,26 @@ class Tensor: broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape)) return x.expand(broadcasted_shape), y.expand(broadcasted_shape) - def _to_const_val(self, x:Union[Tensor, float, int, bool]) -> Union[Tensor, float, int, bool]: + def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]: + # TODO: update with multi return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \ and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x - def add(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor: + def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: x = self._to_const_val(x) return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self - def sub(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor: + def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: x = self._to_const_val(x) return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self) - def mul(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor: + def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: x = self._to_const_val(x) if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self) if x.__class__ is not Tensor and x == -1.0: return -self return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self - def div(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor: + def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: x = self._to_const_val(x) return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501 - def pow(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor: + def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: x = self._to_const_val(x) if not isinstance(x, Tensor) and not reverse: # simple pow identities @@ -814,13 +819,13 @@ class Tensor: to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501 inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan") return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan) - def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse)) - def maximum(self, x:Union[Tensor, float]) -> Tensor: return (selfx).detach().where(self, (self+x)/2)) - def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x)) + # TODO: this implicitly changes dtype with /2 + def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (selfx).detach().where(self, (self+x)/2)) + def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x)) - def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]): + def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]): x_,y = self._broadcasted(input_, match_dtype=False) x,z = x_._broadcasted(other, match_dtype=False) return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))