Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
simple where broadcast (#1643)
tinygrad/mlops.py
@@ -1,9 +1,9 @@
+import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort, DType
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
-import math
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
tinygrad/tensor.py
@@ -1,11 +1,10 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import time
+import time, math
 from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
-import math
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
 
 from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
 from tinygrad.lazy import Device, LazyBuffer
@@ -569,27 +568,27 @@ class Tensor:
 
   # ***** broadcasted binary mlops *****
 
-  def _broadcasted(self, fxn:Type[Function], y:Union[Tensor, float], reverse:bool=False) -> Tensor:
+  def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
     x: Tensor = self
     if not isinstance(y, Tensor):
       y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32)
     if reverse: x, y = y, x
-    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
 
     shape_delta = len(xshape) - len(yshape)
     if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape)
     elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape)
-    if (xshape:=x.shape) == (yshape:=y.shape): return fxn.apply(x, y)
+    if (xshape:=x.shape) == (yshape:=y.shape): return (x, y)
 
     shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)])
     if xshape != shape_ret: x = x.expand(shape_ret)
     if yshape != shape_ret: y = y.expand(shape_ret)
-    return fxn.apply(x, y)
+    return (x, y)
 
-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
+  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
+  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
+  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
+  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     if x.__class__ is not Tensor and not reverse:
       # simple pow identities
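
Note on the hunk above: `_broadcasted` no longer takes the mlop to apply; it only returns the two shape-aligned tensors, and each caller composes it with `mlops.Xxx.apply(*...)`. The alignment rule itself is unchanged: pad the lower-rank shape with leading 1s, then expand every axis to the larger extent. A minimal standalone sketch of that shape rule in plain Python (the helper name `broadcast_pair_shapes` is illustrative only, not part of tinygrad):

from typing import Tuple

def broadcast_pair_shapes(xshape: Tuple[int, ...], yshape: Tuple[int, ...]) -> Tuple[int, ...]:
  # pad the shorter shape with leading 1s so both have the same rank
  delta = len(xshape) - len(yshape)
  if delta > 0: yshape = (1,) * delta + yshape
  elif delta < 0: xshape = (1,) * -delta + xshape
  # the broadcast result takes the larger extent on every axis
  return tuple(max(a, b) for a, b in zip(xshape, yshape))

assert broadcast_pair_shapes((3, 1, 5), (4, 1)) == (3, 4, 5)
assert broadcast_pair_shapes((2, 2), ()) == (2, 2)
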
@@ -614,30 +613,10 @@ class Tensor:
   def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2))
   def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x))
 
   # ***** broadcasted trinary mlops *****
 
   def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
     # TODO: ensure self is non-differentiable, could mess with ceil/float though
-    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
-    x: Tensor = self
-    y: Tensor = Tensor(cast(float, input_), device=self.device, requires_grad=False, dtype=dtype) if input_.__class__ is not Tensor else cast(Tensor, input_)
-    z: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
-    if x.shape == y.shape and y.shape == z.shape: return mlops.Where.apply(x, y, z)
-
-    # TODO refactor this code along with the binary version above
-    len_x_shape, len_y_shape, len_z_shape = len(x.shape), len(y.shape), len(z.shape)
-    max_shape = max(len_x_shape, len_y_shape, len_z_shape)
-
-    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
-    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
-    if len_z_shape != max_shape: z = z.reshape((1,) * (max_shape - len_z_shape) + z.shape)
-
-    shape_ret = tuple([max(x, y, z) for x, y, z in zip(x.shape, y.shape, z.shape)])
-    if x.shape != shape_ret: x = x.expand(shape_ret)
-    if y.shape != shape_ret: y = y.expand(shape_ret)
-    if z.shape != shape_ret: z = z.expand(shape_ret)
-
-    return mlops.Where.apply(x, y, z)
+    x_,y = self._broadcasted(input_)
+    x,z = x_._broadcasted(other)
+    return mlops.Where.apply(x, *y._broadcasted(z))
 
   # ***** binary op wrappers (18 wasted lines to make the typechecker happy) *****
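
Note on the hunk above: the hand-rolled trinary broadcasting in `where` is replaced by three pairwise `_broadcasted` calls — the condition against `input_`, that result against `other`, and finally the two value branches against each other — before `mlops.Where` is applied. A short usage sketch of the resulting behaviour, assuming the post-commit API; the shapes and values are only illustrative:

from tinygrad.tensor import Tensor

cond = Tensor([[1, 0], [0, 1]])   # shape (2, 2)
a = Tensor([10.0, 20.0])          # shape (2,), broadcast up to (2, 2)
out = cond.where(a, 0.0)          # the scalar `other` is broadcast as well
print(out.numpy())                # picks from `a` where cond is nonzero, else 0.0
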
@@ -663,8 +642,8 @@ class Tensor:
   def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
 
-  def __lt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, False)
-  def __gt__(self, x) -> Tensor: return self._broadcasted(mlops.Less, x, True)
+  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
+  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
   def __ge__(self, x) -> Tensor: return 1.0-(self<x)
   def __le__(self, x) -> Tensor: return 1.0-(self>x)
   def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore
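
Note on the hunk above: `__lt__` and `__gt__` both route through the single `mlops.Less` kernel, with the `reverse` flag of `_broadcasted` swapping the operands so that `a > b` is evaluated as `b < a`; `__ge__`, `__le__` and `__ne__` stay derived from those two 0/1 masks. A minimal plain-Python sketch of that derivation (the `less` helper is illustrative, not tinygrad code):

def less(x: float, y: float, reverse: bool = False) -> float:
  # reverse=True swaps the operands, so greater-than reuses the same comparison
  if reverse: x, y = y, x
  return 1.0 if x < y else 0.0

def ge(x, y): return 1.0 - less(x, y)               # x >= y  is  not (x < y)
def ne(x, y): return less(x, y) + less(x, y, True)  # x != y  is  (x < y) or (x > y)

assert less(3, 5) == 1.0 and less(3, 5, reverse=True) == 0.0
assert ge(5, 5) == 1.0 and ne(2, 3) == 1.0 and ne(4, 4) == 0.0
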