use Scalar = Union[float, int, bool] in tensor.py (#3021)

unify the type spec for Tensor creation functions and broadcasted elementwise ops that take python scalar
This commit is contained in:
chenyu
2024-01-05 13:56:26 -05:00
committed by GitHub
parent 60abc62a3f
commit eda43767de

View File

@@ -1,7 +1,7 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast, get_args
from collections import defaultdict
from functools import partialmethod, reduce
import numpy as np
@@ -40,6 +40,8 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
Scalar = Union[float, int, bool]
class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
@@ -50,7 +52,7 @@ class Tensor:
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
no_grad: ClassVar[bool] = False
def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
@@ -64,7 +66,7 @@ class Tensor:
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, (bool, int, float)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
elif isinstance(data, list):
@@ -131,7 +133,7 @@ class Tensor:
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
# TODO: these are good places to start removing numpy
def item(self) -> Union[float, int, bool]:
def item(self) -> Scalar:
assert self.numel() == 1, "must have one element for item"
return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
def data(self) -> memoryview: return self.numpy().data
@@ -184,7 +186,7 @@ class Tensor:
# ***** creation helper functions *****
@staticmethod
def full(shape:Tuple[sint, ...], fill_value: Union[bool, int, float], **kwargs):
def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
@staticmethod
@@ -203,7 +205,7 @@ class Tensor:
def eye(dim:int, **kwargs):
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
def full_like(self, fill_value: Union[bool, int, float], **kwargs):
def full_like(self, fill_value:Scalar, **kwargs):
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
@@ -680,6 +682,8 @@ class Tensor:
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
def cumsum(self, axis:int=0) -> Tensor:
@@ -753,9 +757,9 @@ class Tensor:
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self): return self / (1 + self.abs())
# ***** broadcasted binary mlops *****
# ***** broadcasted elementwise mlops *****
def _broadcasted(self, y:Union[Tensor, float, int, bool], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
x: Tensor = self
if not isinstance(y, Tensor):
# make y a Tensor
@@ -777,25 +781,26 @@ class Tensor:
broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
def _to_const_val(self, x:Union[Tensor, float, int, bool]) -> Union[Tensor, float, int, bool]:
def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
# TODO: update with multi
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
def sub(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
def mul(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
if x.__class__ is not Tensor and x == -1.0: return -self
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
def pow(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
@@ -814,13 +819,13 @@ class Tensor:
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
# TODO: this implicitly changes dtype with /2
def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
x_,y = self._broadcasted(input_, match_dtype=False)
x,z = x_._broadcasted(other, match_dtype=False)
return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))