simple where broadcast (#1643)

George Hotz
2023-08-22 21:24:49 -07:00
committed by GitHub
parent c831218139
commit 41e83be3dd
2 changed files with 16 additions and 37 deletions

tinygrad/mlops.py

@@ -1,9 +1,9 @@
+import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
-import math

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()

tinygrad/tensor.py

@@ -1,11 +1,10 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time
+import time, math
 from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-import math
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
@@ -569,27 +568,27 @@ class Tensor:
   # ***** broadcasted binary mlops *****
-  def _broadcasted(self, fxn:Type[Function], y:Union[Tensor, float], reverse:bool=False) -> Tensor:
+  def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
     if not isinstance(y, Tensor):
       y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
     if reverse: x, y = y, x
-    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
     shape_delta = len(xshape) - len(yshape)
     if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
     elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
-    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
     shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
     if xshape != shape_ret: x = x.expand(shape_ret)
     if yshape != shape_ret: y = y.expand(shape_ret)
-    return fxn.apply(x, y)
+    return (x, y)

-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
+  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
+  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
+  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
+  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     if x.__class__ is not Tensor and not reverse:
       # simple pow identities
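
With this change _broadcasted no longer takes the mlop Function; it only shape-aligns the two operands and returns them as a tuple, and each caller does the apply itself. The broadcast rule is unchanged: pad the shorter shape with leading 1s, then expand every axis to the elementwise maximum. A minimal sketch of that shape rule on plain tuples (the helper name is illustrative, not part of tinygrad):

from typing import Tuple

def broadcast_shape(xshape:Tuple[int, ...], yshape:Tuple[int, ...]) -> Tuple[int, ...]:
  # pad the shorter shape with leading 1s, mirroring the reshape calls in _broadcasted
  delta = len(xshape) - len(yshape)
  if delta > 0: yshape = (1,) * delta + yshape
  elif delta < 0: xshape = (1,) * -delta + xshape
  # each axis expands to the larger of the two sizes, mirroring the expand calls
  return tuple(max(a, b) for a, b in zip(xshape, yshape))

assert broadcast_shape((3, 4), (4,)) == (3, 4)
assert broadcast_shape((2, 1, 5), (3, 1)) == (2, 3, 5)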
@@ -614,30 +613,10 @@ class Tensor:
   def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))

   # ***** broadcasted trinary mlops *****

   def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
-    # TODO: ensure self is non-differentiable, could mess with ceil/float though
-    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
-    x: Tensor = self
-    y: Tensor = Tensor(cast(float, input_), device=self.device, requires_grad=False, dtype=dtype) if input_.__class__ is not Tensor else cast(Tensor, input_)
-    z: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
-    if x.shape == y.shape and y.shape == z.shape: return mlops.Where.apply(x, y, z)
-    # TODO refactor this code along with the binary version above
-    len_x_shape, len_y_shape, len_z_shape = len(x.shape), len(y.shape), len(z.shape)
-    max_shape = max(len_x_shape, len_y_shape, len_z_shape)
-    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
-    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
-    if len_z_shape != max_shape: z = z.reshape((1,) * (max_shape - len_z_shape) + z.shape)
-    shape_ret = tuple([max(x, y, z) for x, y, z in zip(x.shape, y.shape, z.shape)])
-    if x.shape != shape_ret: x = x.expand(shape_ret)
-    if y.shape != shape_ret: y = y.expand(shape_ret)
-    if z.shape != shape_ret: z = z.expand(shape_ret)
-    return mlops.Where.apply(x, y, z)
+    x_,y = self._broadcasted(input_)
+    x,z = x_._broadcasted(other)
+    return mlops.Where.apply(x, *y._broadcasted(z))

   # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
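
where now reuses the same helper three times instead of carrying its own copy of the reshape/expand logic: the condition is broadcast against input_, the result against other, and the remaining pair inside the final apply. A minimal usage sketch, assuming a tinygrad checkout at this commit:

from tinygrad.tensor import Tensor

cond = Tensor([[1.0], [0.0], [1.0]])    # shape (3, 1)
a = Tensor([[10.0, 20.0, 30.0, 40.0]])  # shape (1, 4)
out = cond.where(a, 0.0)                # condition, input_ and other all broadcast to (3, 4)
print(out.shape)                        # (3, 4)
print(out.numpy())                      # rows where cond is 0 come out as zeros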
@@ -663,8 +642,8 @@ class Tensor:
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
-  def __lt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, False)
-  def __gt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, True)
+  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
+  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
   def __ge__(self, x) -> Tensor: return 1.0-(self<x)
   def __le__(self, x) -> Tensor: return 1.0-(self>x)
   def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
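
__lt__ and __gt__ follow the same pattern, applying mlops.Less to the broadcast pair; __ge__, __le__ and __ne__ are still derived from them, so comparisons against a plain float broadcast as well. A small sketch of the resulting behaviour, again assuming a tinygrad checkout at this commit:

from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]])
print((a < 2.5).numpy())   # 0/1 mask from mlops.Less, with the python float broadcast to (2, 2)
print((a != 2.0).numpy())  # derived as (self<x) + (self>x)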