mirror of https://github.com/tinygrad/tinygrad.git
rename mlops to function (#4003)
@@ -38,7 +38,7 @@ class Function:
     ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
     return ret

-import tinygrad.mlops as mlops
+import tinygrad.function as F

 def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Tuple[LazyBuffer, ...]=()):
   if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src)
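For orientation on the `ret._ctx` context line above: the op classes referenced as F.* in this diff follow the Function.apply pattern, where apply runs forward and stashes the ctx on the output so the autograd engine can later walk the graph. A minimal self-contained sketch of that pattern in plain Python (illustration only; the Value class and float data are made up here, not tinygrad's real Tensor/LazyBuffer API):

from __future__ import annotations
from typing import Optional, Tuple

class Value:
  def __init__(self, data: float):
    self.data, self.grad = data, 0.0
    self._ctx: Optional[Function] = None  # set by Function.apply, read when walking backward

class Function:
  def __init__(self, *parents: Value): self.parents = parents
  def forward(self, *xs: float) -> float: raise NotImplementedError
  def backward(self, grad_out: float) -> Tuple[float, ...]: raise NotImplementedError
  @classmethod
  def apply(cls, *vals: Value) -> Value:
    ctx = cls(*vals)                                   # ctx holds the parents and any saved state
    ret = Value(ctx.forward(*[v.data for v in vals]))
    ret._ctx = ctx                                     # analogous to ret._ctx in the hunk above
    return ret

class Mul(Function):
  def forward(self, x: float, y: float) -> float:
    self.x, self.y = x, y
    return x * y
  def backward(self, g: float) -> Tuple[float, float]: return (g * self.y, g * self.x)

a, b = Value(3.0), Value(4.0)
out = Mul.apply(a, b)
# one backward step: propagate a gradient of 1.0 through the stashed ctx
for parent, g in zip(out._ctx.parents, out._ctx.backward(1.0)): parent.grad += g
print(out.data, a.grad, b.grad)  # 12.0 4.0 3.0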
@@ -370,19 +370,19 @@ class Tensor:
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])
-    return mlops.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
+    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
   def expand(self, shape, *args) -> Tensor:
     new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))])
-    return mlops.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
-  def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
-  def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
+    return F.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self
+  def permute(self, order, *args) -> Tensor: return F.Permute.apply(self, order=argfix(order, *args))
+  def flip(self, axis, *args) -> Tensor: return F.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
   def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor:
     if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self
-    return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
+    return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape)))
   def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor:
     if all(x is None or x == (0,0) for x in arg): return self
-    ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
-    return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)
+    ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg)))
+    return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value)

   # ***** movement hlops *****

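A note on the -1 handling in reshape above: with exactly one -1 in the request, prod(new_shape) is negative, so -prod(self.shape) // prod(new_shape) is the positive integer that fills in the missing dimension. A quick stand-alone illustration (shapes chosen arbitrarily):

from math import prod

old_shape, requested = (2, 3, 4), (6, -1)
inferred = tuple(-prod(old_shape) // prod(requested) if s == -1 else s for s in requested)
print(inferred)  # (6, 4), since -24 // -6 == 4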
@@ -611,9 +611,9 @@ class Tensor:
       least_upper_dtype(self.dtype, dtypes.float)
     # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc
     output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype
-    return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype)
+    return self.cast(acc_dtype)._reduce(F.Sum, axis, keepdim).cast(output_dtype)

-  def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim)
+  def max(self, axis=None, keepdim=False): return self._reduce(F.Max, axis, keepdim)
   def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

   def mean(self, axis=None, keepdim=False):
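Note that only F.Sum and F.Max are actual reduce ops here; min is derived through the identity min(x) = -max(-x), so it needs no op or backward of its own. The same identity with plain Python numbers:

xs = [3.0, -7.0, 2.5]
print(-max(-v for v in xs) == min(xs))  # True: -max([-3.0, 7.0, -2.5]) == -7.0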
@@ -807,18 +807,18 @@ class Tensor:

   # ***** mlops (unary) *****

-  def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False))
-  def neg(self): return mlops.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
-  def contiguous(self): return mlops.Contiguous.apply(self)
-  def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
-  def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
+  def logical_not(self): return F.Eq.apply(*self._broadcasted(False))
+  def neg(self): return F.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not()
+  def contiguous(self): return F.Contiguous.apply(self)
+  def contiguous_backward(self): return F.ContiguousBackward.apply(self)
+  def log(self): return F.Log.apply(self.cast(least_upper_float(self.dtype)))
   def log2(self): return self.log()/math.log(2)
-  def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
-  def exp2(self): return mlops.Exp.apply(self*math.log(2))
-  def relu(self): return mlops.Relu.apply(self)
-  def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
-  def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
-  def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
+  def exp(self): return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
+  def exp2(self): return F.Exp.apply(self*math.log(2))
+  def relu(self): return F.Relu.apply(self)
+  def sigmoid(self): return F.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
+  def sin(self): return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
+  def sqrt(self): return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
   def rsqrt(self): return self.reciprocal().sqrt()
   def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
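Several of the unary ops in this hunk are compositions of the primitives rather than ops of their own: log2 = log / ln 2, exp2(x) = exp(x * ln 2), rsqrt(x) = sqrt(1/x), cos(x) = sin(pi/2 - x). A quick numeric sanity check of those identities using only the math module (arbitrary test value):

import math

x = 0.731
print(math.isclose(math.log(x) / math.log(2), math.log2(x)))  # log2 via log
print(math.isclose(math.exp(x * math.log(2)), 2 ** x))        # exp2 via exp
print(math.isclose(math.sqrt(1 / x), 1 / math.sqrt(x)))       # rsqrt == sqrt(reciprocal)
print(math.isclose(math.sin(math.pi / 2 - x), math.cos(x)))   # cos via sin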
@@ -835,14 +835,14 @@ class Tensor:
   def clip(self, min_, max_): return self.maximum(min_).minimum(max_)
   def abs(self): return self.relu() + (-self).relu()
   def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype)
-  def reciprocal(self): return mlops.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
+  def reciprocal(self): return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))

   # ***** activation functions (unary) *****

   def elu(self, alpha=1.0): return self.relu() - alpha*(1-self.exp()).relu()
   def celu(self, alpha=1.0): return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
   def swish(self): return self * self.sigmoid()
-  def silu(self): return self.swish() # The SiLU function is also known as the swish function.
+  def silu(self): return self.swish() # The SiLU function is also known as the swish F.
   def relu6(self): return self.relu() - (self-6).relu()
   def hardswish(self): return self * (self+3).relu6() * (1/6)
   def tanh(self): return 2.0 * ((2.0 * self).sigmoid()) - 1.0
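Likewise, the activations here are built from relu/sigmoid/exp: abs(x) = relu(x) + relu(-x), relu6(x) = relu(x) - relu(x-6), tanh(x) = 2*sigmoid(2x) - 1. A small check of those identities in plain Python (arbitrary test values):

import math

relu = lambda v: max(v, 0.0)
sigmoid = lambda v: 1 / (1 + math.exp(-v))

x = -2.3
print(relu(x) + relu(-x) == abs(x))                        # abs via two relus
print(relu(7.5) - relu(7.5 - 6) == 6.0)                    # relu6 saturates at 6
print(math.isclose(2 * sigmoid(2 * x) - 1, math.tanh(x)))  # tanh via sigmoid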
@@ -890,28 +890,28 @@ class Tensor:

   def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     x = self._to_const_val(x)
-    return mlops.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self
+    return F.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self
   def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     x = self._to_const_val(x)
-    return mlops.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self)
+    return F.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self)
   def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     x = self._to_const_val(x)
-    if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self)
+    if not isinstance(x, Tensor) and x == 0.0: return F.Zero.apply(self)
     if not isinstance(x, Tensor) and x == -1.0: return -self
-    return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
+    return F.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self
   def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
     x = self._to_const_val(x)
     if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
-    if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return mlops.Div.apply(*self._broadcasted(x, reverse))
-    return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
-  def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
+    if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return F.Div.apply(*self._broadcasted(x, reverse))
+    return F.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
+  def xor(self, x:Tensor, reverse=False) -> Tensor: return F.Xor.apply(*self._broadcasted(x, reverse))

   def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     x = self._to_const_val(x)
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
       if x < 0: return self.reciprocal().pow(-x)
-      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1)
+      if x in [3,2,1,0]: return functools.reduce(lambda acc,_: acc * self, range(int(x)), F.Zero.apply(self)+1)
       if x == 0.5: return self.sqrt()
     if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp()
     ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp()
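The small-integer pow fast path above unrolls x**n into repeated multiplies via functools.reduce, seeding the accumulator with F.Zero.apply(self)+1 (roughly, a tensor of ones matching self). The same reduce with a plain float standing in for the tensor (illustration only):

import functools

x, n = 1.7, 3
out = functools.reduce(lambda acc, _: acc * x, range(n), 1.0)  # the 1.0 seed plays the role of Zero.apply(self)+1
print(out, x ** n)  # ~4.913 both ways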
@@ -936,7 +936,7 @@ class Tensor:
     elif isinstance(other, Tensor): other, input_ = other._broadcasted(input_)
     x_,y = self._broadcasted(input_, match_dtype=False)
     x,z = x_._broadcasted(other, match_dtype=False)
-    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
+    return F.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))

   # ***** op wrappers (wasted lines to make the typechecker happy) *****

@@ -966,11 +966,11 @@ class Tensor:
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
   def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))

-  def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
-  def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
+  def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
+  def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
   def __ge__(self, x) -> Tensor: return (self<x).logical_not()
   def __le__(self, x) -> Tensor: return (self>x).logical_not()
-  def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
+  def __eq__(self, x) -> Tensor: return F.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override]
   def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override]

   # ***** functional nn ops *****
@@ -1030,10 +1030,10 @@ class Tensor:
     # hack for devices that don't support bfloat16
     assert self.dtype == dtypes.bfloat16
     return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
-  def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else mlops.Cast.apply(self, dtype=dtype)
+  def cast(self, dtype:DType) -> Tensor: return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
   def bitcast(self, dtype:DType) -> Tensor:
     if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
-    return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
+    return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
   def float(self) -> Tensor: return self.cast(dtypes.float32)
   def half(self) -> Tensor: return self.cast(dtypes.float16)

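The bfloat16 fallback above works because bfloat16 is simply the top 16 bits of a float32, so widening the uint16 payload and shifting it left by 16 bits (the .mul(1<<16)) yields a valid float32 bit pattern. A struct-based illustration of the same bit manipulation (not tinygrad code; the value is chosen to be exactly representable in bfloat16):

import struct

def f32_bits(x: float) -> int: return struct.unpack("<I", struct.pack("<f", x))[0]
def bits_f32(b: int) -> float: return struct.unpack("<f", struct.pack("<I", b))[0]

x = 3.140625                      # fits in bfloat16's 8-bit exponent and 7-bit mantissa
bf16 = f32_bits(x) >> 16          # keep sign + exponent + top 7 mantissa bits
print(bits_f32(bf16 << 16) == x)  # True: shifting back up reconstructs the float32 exactly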