mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use Scalar = Union[float, int, bool] in tensor.py (#3021)
unify the type spec for Tensor creation functions and broadcasted elementwise ops that take python scalar
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math, itertools
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, DefaultDict, cast, get_args
|
||||
from collections import defaultdict
|
||||
from functools import partialmethod, reduce
|
||||
import numpy as np
|
||||
@@ -40,6 +40,8 @@ def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str,
|
||||
if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
|
||||
return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None)
|
||||
|
||||
Scalar = Union[float, int, bool]
|
||||
|
||||
class Tensor:
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
__deletable__ = ('_ctx',)
|
||||
@@ -50,7 +52,7 @@ class Tensor:
|
||||
def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev
|
||||
|
||||
no_grad: ClassVar[bool] = False
|
||||
def __init__(self, data:Union[None, bool, int, float, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
|
||||
def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer],
|
||||
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
||||
@@ -64,7 +66,7 @@ class Tensor:
|
||||
# internal variables used for autograd graph construction
|
||||
self._ctx: Optional[Function] = None
|
||||
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
|
||||
elif isinstance(data, (bool, int, float)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
|
||||
elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
|
||||
elif isinstance(data, bytes): data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8))
|
||||
elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device)
|
||||
elif isinstance(data, list):
|
||||
@@ -131,7 +133,7 @@ class Tensor:
|
||||
def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
|
||||
# TODO: these are good places to start removing numpy
|
||||
def item(self) -> Union[float, int, bool]:
|
||||
def item(self) -> Scalar:
|
||||
assert self.numel() == 1, "must have one element for item"
|
||||
return cast(Buffer, self.contiguous().realize().lazydata.base.realized).toCPU().item()
|
||||
def data(self) -> memoryview: return self.numpy().data
|
||||
@@ -184,7 +186,7 @@ class Tensor:
|
||||
# ***** creation helper functions *****
|
||||
|
||||
@staticmethod
|
||||
def full(shape:Tuple[sint, ...], fill_value: Union[bool, int, float], **kwargs):
|
||||
def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs):
|
||||
return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape)
|
||||
|
||||
@staticmethod
|
||||
@@ -203,7 +205,7 @@ class Tensor:
|
||||
def eye(dim:int, **kwargs):
|
||||
return Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim)
|
||||
|
||||
def full_like(self, fill_value: Union[bool, int, float], **kwargs):
|
||||
def full_like(self, fill_value:Scalar, **kwargs):
|
||||
return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs)
|
||||
def zeros_like(self, **kwargs): return self.full_like(0, **kwargs)
|
||||
def ones_like(self, **kwargs): return self.full_like(1, **kwargs)
|
||||
@@ -680,6 +682,8 @@ class Tensor:
|
||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||
return (x*w).sum(-1).cast(least_upper_dtype(x.dtype, w.dtype))
|
||||
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
|
||||
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor:
|
||||
return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
|
||||
def cumsum(self, axis:int=0) -> Tensor:
|
||||
@@ -753,9 +757,9 @@ class Tensor:
|
||||
def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log()
|
||||
def softsign(self): return self / (1 + self.abs())
|
||||
|
||||
# ***** broadcasted binary mlops *****
|
||||
# ***** broadcasted elementwise mlops *****
|
||||
|
||||
def _broadcasted(self, y:Union[Tensor, float, int, bool], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
|
||||
def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
|
||||
x: Tensor = self
|
||||
if not isinstance(y, Tensor):
|
||||
# make y a Tensor
|
||||
@@ -777,25 +781,26 @@ class Tensor:
|
||||
broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
|
||||
return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
|
||||
|
||||
def _to_const_val(self, x:Union[Tensor, float, int, bool]) -> Union[Tensor, float, int, bool]:
|
||||
def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]:
|
||||
# TODO: update with multi
|
||||
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \
|
||||
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
|
||||
|
||||
def add(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
|
||||
def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
|
||||
def sub(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
|
||||
def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
|
||||
def mul(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
|
||||
def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
|
||||
if x.__class__ is not Tensor and x == -1.0: return -self
|
||||
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
|
||||
def div(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
|
||||
def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
|
||||
def pow(self, x:Union[Tensor, float, int, bool], reverse=False) -> Tensor:
|
||||
def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
if not isinstance(x, Tensor) and not reverse:
|
||||
# simple pow identities
|
||||
@@ -814,13 +819,13 @@ class Tensor:
|
||||
to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign # noqa: E501
|
||||
inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan")
|
||||
return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan)
|
||||
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
|
||||
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
|
||||
def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
|
||||
# TODO: this implicitly changes dtype with /2
|
||||
def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
|
||||
def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x))
|
||||
|
||||
def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
|
||||
def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]):
|
||||
x_,y = self._broadcasted(input_, match_dtype=False)
|
||||
x,z = x_._broadcasted(other, match_dtype=False)
|
||||
return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
|
||||
|
||||
Reference in New Issue
Block a user