From a6e496b1950732ebc127f33b5ded1bd08f309b1b Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 26 Jan 2025 18:58:02 +0900
Subject: [PATCH] remove Function class [pr] (#8753)

* remove Function class [pr]

* actually remove function

* fix docs
---
 docs/developer/developer.md |   2 +-
 docs/developer/function.md  |  33 -----------
 mkdocs.yml                  |   1 -
 tinygrad/function.py        | 108 ----------------------------------
 tinygrad/tensor.py          | 112 +++++++++++++++++-------------------
 5 files changed, 53 insertions(+), 203 deletions(-)
 delete mode 100644 docs/developer/function.md
 delete mode 100644 tinygrad/function.py

diff --git a/docs/developer/developer.md b/docs/developer/developer.md
index b40d715af0..39e9e0901b 100644
--- a/docs/developer/developer.md
+++ b/docs/developer/developer.md
@@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not

 ## Frontend

-Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md).
+Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).

 The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There's two types of UOps, base and view. base contains compute into a contiguous buffer, and view is a view (specified by a ShapeTracker). Inputs to a base can be either base or view, inputs to a view can only be a single base.
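[Editor's note, not part of the patch] The rewritten paragraph above is the core claim of this PR: a Tensor expression only builds a `UOp` graph until it is realized. A minimal sketch of that behavior, assuming a tinygrad checkout with this change applied (the exact `UOp` printout varies by version):

```python
# illustrative only; Tensor and .lazydata come from the tinygrad source shown in this diff
from tinygrad import Tensor

a = Tensor([1.0, 2.0, 3.0])
b = (a + 1).relu()   # no kernel has run yet; this only builds UOps
print(b.lazydata)    # the UOp graph backing this Tensor
print(b.numpy())     # realization happens here; expected [2. 3. 4.]
```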
diff --git a/docs/developer/function.md b/docs/developer/function.md
deleted file mode 100644
index 9f1b85f8cd..0000000000
--- a/docs/developer/function.md
+++ /dev/null
@@ -1,33 +0,0 @@
-::: tinygrad.function
-    options:
-        members: [
-            "Contiguous",
-            "ContiguousBackward",
-            "Cast",
-            "Neg",
-            "Reciprocal",
-            "Sin",
-            "Relu",
-            "Log",
-            "Exp",
-            "Sqrt",
-            "Sigmoid",
-            "Sign",
-            "Less",
-            "Eq",
-            "Xor",
-            "Add",
-            "Sub",
-            "Mul",
-            "Div",
-            "Where",
-            "Sum",
-            "Max",
-            "Expand",
-            "Reshape",
-            "Permute",
-            "Pad",
-            "Shrink",
-            "Flip",
-        ]
-        show_source: false
diff --git a/mkdocs.yml b/mkdocs.yml
index 291998dac5..38419a5708 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -22,7 +22,6 @@ nav:
   - Runtime: runtime.md
   - Developer:
     - Intro: developer/developer.md
-    - Function (autodiff): developer/function.md
     - UOp: developer/uop.md
     - Runtime:
       - developer/runtime.md
diff --git a/tinygrad/function.py b/tinygrad/function.py
deleted file mode 100644
index af5cecb8eb..0000000000
--- a/tinygrad/function.py
+++ /dev/null
@@ -1,108 +0,0 @@
-"""This is where the forwards and backwards passes live."""
-import math
-from tinygrad.dtype import DType
-from tinygrad.ops import Ops, sint, UOp
-from tinygrad.tensor import Function
-
-class Contiguous(Function):
-  def forward(self, x:UOp) -> UOp: return x.contiguous()
-
-class ContiguousBackward(Function):
-  def forward(self, x:UOp) -> UOp: return x.contiguous_backward()
-
-class Cast(Function):
-  def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: return x.bitcast(dtype) if bitcast else x.cast(dtype)
-
-# ************* unary ops *************
-
-class Reciprocal(Function):
-  def forward(self, x:UOp) -> UOp: return x.reciprocal()
-
-class Sin(Function):
-  def forward(self, x:UOp) -> UOp: return x.sin()
-
-class Relu(Function):
-  def forward(self, x:UOp) -> UOp: return (x>0).where(x, 0)
-
-class Log(Function):
-  def forward(self, x:UOp) -> UOp: return x.log2() * math.log(2)
-
-class Exp(Function):
-  def forward(self, x:UOp) -> UOp: return (x * (1/math.log(2))).exp2()
-
-class Sqrt(Function):
-  def forward(self, x:UOp) -> UOp: return x.sqrt()
-
-class Sign(Function):
-  # NOTE: the x*0 is to match torch behavior without function.py
-  def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
-
-# ************* binary ops *************
-
-class Less(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x<y
-
-class Neq(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
-
-class Xor(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x^y
-
-class BitwiseAnd(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x&y
-
-class BitwiseOr(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x|y
-
-class Threefry(Function):
-  def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed)
-
-class Add(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x+y
-
-class Mul(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x * y
-
-class IDiv(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x // y
-
-class Mod(Function):
-  def forward(self, x:UOp, y:UOp) -> UOp: return x % y
-
-# ************* ternary ops *************
-
-class Where(Function):
-  def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: return x.where(y, z)
-
-
-# ************* reduce ops *************
-
-class Sum(Function):
-  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.ADD, axis)
-
-class Prod(Function):
-  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MUL, axis)
-
-class Max(Function):
-  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MAX, axis)
-
-# ************* movement ops *************
-
-# NOTE: this is sum in reverse
-class Expand(Function):
-  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.expand(shape)
-
-class Reshape(Function):
-  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.reshape(shape)
-
-class Permute(Function):
-  def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: return x.permute(order)
-
-class Pad(Function):
-  def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: return x.pad(arg)
-
-class Shrink(Function):
-  def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: return x.shrink(arg)
-
-class Flip(Function):
-  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.stride(tuple([-1 if i in axis else 1 for i in range(len(x.shape))]))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 1bcd22af53..1b81c74601 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
+from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
@@ -42,26 +42,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
     if s is ns: continue
     t.lazydata = ns

-# **** start with two base classes, Tensor and Function ****
-
-class Function:
-  def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
-    self.device = device
-    self.needs_input_grad = [t.requires_grad for t in tensors]
-    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
-    if self.requires_grad: self.parents = tensors
-    self.metadata = metadata
-
-  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
-
-  @classmethod
-  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
-    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
-    ret = Tensor.__new__(Tensor)
-    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
-    return ret
-
-import tinygrad.function as F
+# **** Tensor helper functions ****

 def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
   if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
@@ -239,6 +220,17 @@ class Tensor(SimpleMathTrait):
   @property
   def dtype(self) -> DType: return self.lazydata.dtype

+  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
+    ret = Tensor.__new__(Tensor)
+    needs_input_grad = [t.requires_grad for t in (self,)+x]
+    ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
+    ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
+    return ret
+
+  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    lhs,rhs = self._broadcasted(x, reverse)
+    return lhs._apply_uop(fxn, rhs)
+
   # ***** data handlers ****

   def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@@ -497,7 +489,7 @@ class Tensor(SimpleMathTrait):
   @staticmethod
   def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
     x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
-    x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
+    x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
     counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
     return counts0.cat(counts1)
@@ -961,7 +953,7 @@ class Tensor(SimpleMathTrait):
     # resolve -1
     if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
-    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
+    return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self

   def expand(self, shape, *args) -> Tensor:
     """
@@ -994,7 +986,7 @@ class Tensor(SimpleMathTrait):
     """
     order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
     if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
-    return F.Permute.apply(self, order=order_arg)
+    return self._apply_uop(UOp.permute, arg=order_arg)

   def flip(self, axis, *args) -> Tensor:
     """
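[Editor's note, not part of the patch] The hunks above replace `Function.apply` with the new `Tensor._apply_uop` / `Tensor._apply_broadcasted_uop` helpers, which do the same wiring (merge `requires_grad` across inputs, build the output's `lazydata` from the inputs' `lazydata`) but take a bound `UOp` method instead of a `Function` subclass. A hedged sketch of the new path, assuming this PR is applied; the helpers are private and are called directly here only to make the mapping visible:

```python
# illustrative only; UOp.mul is the same bound method the rewritten Tensor.mul now passes in
from tinygrad import Tensor
from tinygrad.ops import UOp

a, b = Tensor([1.0, 2.0, 3.0]), Tensor([10.0, 20.0, 30.0])
via_helper = a._apply_broadcasted_uop(UOp.mul, b)  # what a.mul(b) now does internally
print(via_helper.numpy())                          # expected: [10. 40. 90.]
print(a.mul(b).numpy())                            # same result through the public op
```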
@@ -1014,7 +1006,7 @@ class Tensor(SimpleMathTrait):
     """
     axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
     if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
-    return F.Flip.apply(self, axis=axis_arg)
+    return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))

   def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
     """
@@ -1034,7 +1026,7 @@ class Tensor(SimpleMathTrait):
     """
     if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
-    return F.Shrink.apply(self, arg=tuple(shrink_arg))
+    return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

   def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
     """
@@ -1078,7 +1070,8 @@ class Tensor(SimpleMathTrait):
     if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
     X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
     if mode == "constant":
-      def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
+      def _constant(x:Tensor,px,v):
+        return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
       return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
         _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
@@ -1568,10 +1561,10 @@ class Tensor(SimpleMathTrait):

   # ***** reduce ops *****

-  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
+  def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
     axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
     if self.ndim == 0: axis = ()
-    ret = fxn.apply(self, axis=axis)
+    ret = self._apply_uop(UOp.r, op=op, axis=axis)
     return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))

   def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1598,7 +1591,7 @@ class Tensor(SimpleMathTrait):
     print(t.sum(axis=1).numpy())
     ```
     """
-    ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
+    ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
     return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

   def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1625,7 +1618,7 @@ class Tensor(SimpleMathTrait):
     print(t.prod(axis=1).numpy())
     ```
     """
-    return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
+    return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)

   def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
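[Editor's note, not part of the patch] With `_reduce` now taking an `Ops` value, `sum`, `prod`, and `max` differ only in the enum they pass. A small sketch of that mapping, assuming this PR is applied; `_reduce` is private and is called directly here only for illustration:

```python
# illustrative only; Ops comes from tinygrad.ops as in the deleted function.py imports
from tinygrad import Tensor
from tinygrad.ops import Ops

t = Tensor([[1.0, 2.0], [3.0, 4.0]])
print(t.sum(axis=1).numpy())               # [3. 7.]
print(t._reduce(Ops.ADD, axis=1).numpy())  # same reduction, called directly
print(t._reduce(Ops.MAX, axis=0).numpy())  # [3. 4.], what t.max(axis=0) does
```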
@@ -1648,7 +1641,7 @@ class Tensor(SimpleMathTrait):
     print(t.max(axis=1, keepdim=True).numpy())
     ```
     """
-    return self._reduce(F.Max, axis, keepdim)
+    return self._reduce(Ops.MAX, axis, keepdim)

   def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
@@ -2485,7 +2478,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([False, True]).logical_not().numpy())
     ```
     """
-    return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
+    return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)

   def neg(self):
     """
     Negates the tensor element-wise.
@@ -2499,12 +2492,12 @@ class Tensor(SimpleMathTrait):
     """
     Returns a contiguous tensor.
     """
-    return F.Contiguous.apply(self)
+    return self._apply_uop(UOp.contiguous)

   def contiguous_backward(self):
     """
     Inserts a contiguous operation in the backward pass.
     """
-    return F.ContiguousBackward.apply(self)
+    return self._apply_uop(UOp.contiguous_backward)

   def log(self):
     """
     Computes the natural logarithm element-wise.
@@ -2515,7 +2508,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 4., 8.]).log().numpy())
     ```
     """
-    return F.Log.apply(self.cast(least_upper_float(self.dtype)))
+    return self.log2()*math.log(2)

   def log2(self):
     """
     Computes the base-2 logarithm element-wise.
@@ -2526,7 +2519,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 4., 8.]).log2().numpy())
     ```
     """
-    return self.log()/math.log(2)
+    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)

   def exp(self):
     """
     Computes the exponential function element-wise.
@@ -2537,7 +2530,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([0., 1., 2., 3.]).exp().numpy())
     ```
     """
-    return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
+    return self.mul(1/math.log(2)).exp2()

   def exp2(self):
     """
     Computes the base-2 exponential function element-wise.
@@ -2548,8 +2541,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([0., 1., 2., 3.]).exp2().numpy())
     ```
     """
-    return F.Exp.apply(self*math.log(2))
-
+    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)

   def relu(self):
     """
     Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2560,7 +2552,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
     ```
     """
-    return F.Relu.apply(self)
+    return (self>0).where(self, 0)

   def sigmoid(self):
     """
@@ -2596,7 +2588,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
     ```
     """
-    return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
+    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)

   def rsqrt(self):
     """
     Computes the reciprocal of the square root of the tensor element-wise.
@@ -2614,7 +2606,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
     ```
     """
-    return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
+    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)

   def cos(self):
     """
     Computes the cosine of the tensor element-wise.
@@ -2773,7 +2765,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
     ```
     """
-    return F.Sign.apply(self)
+    return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0

   def abs(self):
     """
     Computes the absolute value of the tensor element-wise.
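[Editor's note, not part of the patch] The unary rewrites above lean on the identities log(x) = log2(x) * ln(2) and exp(x) = exp2(x * (1/ln(2))), and express relu as a `where` select on x > 0. A plain-Python sanity check of those identities (no tinygrad required; the `relu` lambda is only an illustrative stand-in for the new Tensor.relu):

```python
import math

x = 1.7
assert math.isclose(math.log(x), math.log2(x) * math.log(2))   # log via log2
assert math.isclose(math.exp(x), 2 ** (x * (1 / math.log(2)))) # exp via exp2
relu = lambda v: v if v > 0 else 0.0                           # where(x>0, x, 0)
assert relu(-3.0) == 0.0 and relu(2.5) == 2.5
```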
@@ -2791,7 +2783,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
     ```
     """
-    return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
+    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)

   # ***** activation functions *****

@@ -3069,7 +3061,7 @@ class Tensor(SimpleMathTrait):
     # for each dimension, check either dim is 1, or it does not change
     if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
-    return F.Expand.apply(self.reshape(shape), shape=new_shape)
+    return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)

   def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
     x: Tensor = self
@@ -3113,7 +3105,7 @@ class Tensor(SimpleMathTrait):
     print(t.add(Tensor([[2.0], [3.5]])).numpy())
     ```
     """
-    return F.Add.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.add, x, reverse)

   def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3154,7 +3146,7 @@ class Tensor(SimpleMathTrait):
     print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
     ```
     """
-    return F.Mul.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.mul, x, reverse)

   def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3167,7 +3159,7 @@ class Tensor(SimpleMathTrait):
     print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
     ```
     """
-    return F.IDiv.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.idiv, x, reverse)

   def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3202,7 +3194,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     a, b = self._broadcasted(x, reverse)
-    return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
+    return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))

   def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3218,7 +3210,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return F.Xor.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.xor, x, reverse)

   def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3233,7 +3225,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)

   def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
@@ -3248,7 +3240,7 @@ class Tensor(SimpleMathTrait):
     ```
     """
     if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
-    return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
+    return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)

   def bitwise_not(self) -> Tensor:
     """
@@ -3379,7 +3371,7 @@ class Tensor(SimpleMathTrait):
     elif isinstance(y, Tensor): y, x = y._broadcasted(x)
     cond, x = self._broadcasted(x, match_dtype=False)
     cond, y = cond._broadcasted(y, match_dtype=False)
-    return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
+    return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))

   def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
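[Editor's note, not part of the patch] In the mod hunk above, `UOp.mod` produces a truncated (C-style) remainder `r`, and the `b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))` term adds one extra `b` exactly when `r` and `b` disagree in sign, which recovers Python's floored semantics. A plain-Python check of that fix-up (the helper name is illustrative, not tinygrad API):

```python
import math

def python_mod_from_trunc(a: float, b: float) -> float:
  r = math.fmod(a, b)  # truncated remainder, the UOp.mod convention
  return r + b * ((r < 0 and b > 0) or (r > 0 and b < 0))  # same fix-up as the patch

for a, b in [(-7, 3), (7, -3), (-7, -3), (7, 3)]:
  assert python_mod_from_trunc(a, b) == a % b  # matches Python's floored %
```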
@@ -3409,9 +3401,9 @@ class Tensor(SimpleMathTrait):
   def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
   def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))

-  def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
-  def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
-  def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
+  def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
+  def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
+  def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)

   def __eq__(self, x) -> Tensor: return self.eq(x)  # type: ignore[override]

@@ -3757,8 +3749,8 @@ class Tensor(SimpleMathTrait):
     """
     if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
       # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
-      return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
-    return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
+      return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
+    return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)

   def bitcast(self, dtype:DTypeLike) -> Tensor:
     """
@@ -3783,7 +3775,7 @@ class Tensor(SimpleMathTrait):
       tmp = self.bitcast(old_uint)
       if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
       return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
-    return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
+    return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self

   def float(self) -> Tensor:
     """
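[Editor's endnote, not part of the patch; the tensor.py diff is truncated above.] The cast hunk keeps the behavior flagged by the NOTE in the diff: a float cast to uint8/uint16 goes through int32 first, so values outside the unsigned range wrap around rather than saturate. A plain-Python illustration of that wrap-around, assuming the intermediate value fits in int32 (the helper name is illustrative, not tinygrad API):

```python
def float_to_uint8_like_cast(x: float) -> int:
  as_i32 = int(x)       # first cast: float -> int32 (truncation toward zero)
  return as_i32 & 0xff  # second cast: int32 -> uint8 wraps modulo 256

print(float_to_uint8_like_cast(300.7))  # 44, i.e. 300 % 256, not clamped to 255
print(float_to_uint8_like_cast(-1.0))   # 255, wrap-around rather than 0
```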