diff --git a/tinygrad/mixin/dtype.py b/tinygrad/mixin/dtype.py index fccfb58da0..9094925207 100644 --- a/tinygrad/mixin/dtype.py +++ b/tinygrad/mixin/dtype.py @@ -8,11 +8,26 @@ class DTypeMixin: def cast(self, dtype:DType) -> Self: raise NotImplementedError def element_size(self) -> int: - """Returns the number of bytes of a single element in the tensor.""" + """ + Returns the size in bytes of an individual element in the tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([5], dtype=dtypes.int16) + print(t.element_size()) + ``` + """ return self.dtype.itemsize def is_floating_point(self) -> bool: - """Returns `True` if the tensor contains floating point types, i.e. is one of `bool`, `float16`, `bfloat16`, `float32`, `float64`.""" + """ + Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`, + `dtypes.float16`, `dtypes.bfloat16`. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([8, 9], dtype=dtypes.float32) + print(t.is_floating_point()) + ``` + """ return dtypes.is_float(self.dtype) def float(self) -> Self: diff --git a/tinygrad/mixin/math.py b/tinygrad/mixin/math.py index 7934d09efa..6b59480bd2 100644 --- a/tinygrad/mixin/math.py +++ b/tinygrad/mixin/math.py @@ -195,10 +195,10 @@ class MathMixin(DTypeMixin): return self.mod(x, True) def __lt__(self, x: Self | ConstType) -> Self: - return self.alu(Ops.CMPLT, self.ufix(x)) + return self._binop(Ops.CMPLT, x, False) def __gt__(self, x: Self | ConstType) -> Self: - return self.ufix(x).alu(Ops.CMPLT, self) + return self._binop(Ops.CMPLT, x, True) def __ge__(self, x: Self | ConstType) -> Self: return (self < x).logical_not() @@ -207,7 +207,7 @@ class MathMixin(DTypeMixin): return (self > x).logical_not() def ne(self, x: Self | ConstType) -> Self: - return self.alu(Ops.CMPNE, self.ufix(x)) + return self._binop(Ops.CMPNE, x, False) def eq(self, x: Self | ConstType) -> Self: return self.ne(x).logical_not() @@ -236,7 +236,17 @@ class MathMixin(DTypeMixin): return self.rshift(x, True) def maximum(self, x: Self | ConstType) -> Self: - return self.alu(Ops.MAX, self.ufix(x)) + """ + Computes element-wise maximum of `self` and `x`. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1, 2, 3]).maximum(1).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy()) + ``` + """ + return self._binop(Ops.MAX, x, False) def minimum(self, x: Self | ConstType) -> Self: return -(-self).maximum(-self.ufix(x)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 57e373ece9..0b5733166b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -189,12 +189,10 @@ class Tensor(OpMixin): all_tensors[weakref.ref(ret)] = None return ret - def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor: - lhs,rhs = self._broadcasted(x, reverse) - return lhs._apply_uop(fxn, rhs) - # _binop and alu are used by MathMixin - def _binop(self, op, x, reverse): return self._apply_broadcasted_uop(lambda *u: UOp.alu(u[0], op, *u[1:]), x, reverse) + def _binop(self, op, x, reverse): + lhs,rhs = self._broadcasted(x, reverse) + return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs) def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src) def requires_grad_(self, requires_grad=True) -> Tensor: @@ -2822,7 +2820,7 @@ class Tensor(OpMixin): print(Tensor([False, True]).logical_not().numpy()) ``` """ - return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True) + return self.cast(dtypes.bool).ne(True) def neg(self) -> Tensor: """ @@ -3197,7 +3195,7 @@ class Tensor(OpMixin): numerator, denominator = numerator.cast(dt), denominator.cast(dt) if rounding_mode == "trunc": return numerator.idiv(denominator) if rounding_mode == "floor": - truncate_div, truncate_mod = numerator.idiv(denominator), numerator._apply_broadcasted_uop(UOp.mod, denominator) + truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False) opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0)) return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div) if rounding_mode == "trunc": return d.trunc().cast(output_dtype) @@ -3279,19 +3277,6 @@ class Tensor(OpMixin): # NOTE: pow(int, float) -> int return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret - def maximum(self, x:Tensor|ConstType) -> Tensor: - """ - Computes element-wise maximum of `self` and `x`. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1, 2, 3]).maximum(1).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy()) - ``` - """ - return self._apply_broadcasted_uop(UOp.maximum, x) - def minimum(self, x:Tensor|ConstType) -> Tensor: """ Computes element-wise minimum of `self` and `x`. @@ -3376,10 +3361,6 @@ class Tensor(OpMixin): def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) # type: ignore[misc] def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) # type: ignore[misc] - def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False) - def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True) - def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False) - def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override] # ***** encoding/decoding ops ***** @@ -3743,17 +3724,6 @@ class Tensor(OpMixin): # ***** Tensor Properties ***** - def element_size(self) -> int: - """ - Returns the size in bytes of an individual element in the tensor. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([5], dtype=dtypes.int16) - print(t.element_size()) - ``` - """ - return self.dtype.itemsize - def nbytes(self) -> int: """ Returns the total number of bytes of all elements in the tensor. @@ -3765,18 +3735,6 @@ class Tensor(OpMixin): """ return int(self.numel()) * self.element_size() - def is_floating_point(self) -> bool: - """ - Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`, - `dtypes.float16`, `dtypes.bfloat16`. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([8, 9], dtype=dtypes.float32) - print(t.is_floating_point()) - ``` - """ - return dtypes.is_float(self.dtype) - def size(self, dim:int|None=None) -> sint|tuple[sint, ...]: """ Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.