diff --git a/tinygrad/mixin/elementwise.py b/tinygrad/mixin/elementwise.py index 2f376c578d..549f162860 100644 --- a/tinygrad/mixin/elementwise.py +++ b/tinygrad/mixin/elementwise.py @@ -287,8 +287,36 @@ class ElementwiseMixin(DTypeMixin): """ return self._binop(Ops.MAX, x, False) + def _inverse(self) -> Self: return -self if self.is_floating_point() else ~self + def minimum(self, x: Self | ConstType) -> Self: - return -(-self).maximum(-self.ufix(x)) + """ + Computes element-wise minimum of `self` and `x`. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1, 2, 3]).minimum(1).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy()) + ``` + """ + t, x = self._broadcasted(x) + return t._inverse().maximum(x._inverse())._inverse() + + def copysign(self, other: Self | ConstType) -> Self: + """ + Returns a tensor with the magnitude of `self` and the sign of `other`, elementwise. + """ + # NOTE: torch always returns in float, we return based on the broadcasting rule. + other = self._broadcasted(other)[1] + return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1) + + def logaddexp(self, other: Self | ConstType) -> Self: + """ + Calculates (self.exp()+other.exp()).log(), elementwise. 
+ """ + m = self.maximum(other) + return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m def where(self, x: Self | ConstType, y: Self | ConstType) -> Self: if isinstance(x, type(self)): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9a0f305ae2..427a7e44ec 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1699,8 +1699,6 @@ class Tensor(OpMixin): """ return self._reduce(Ops.MAX, axis, keepdim) - def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self - def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor: """ Returns the minimum value of the tensor along the specified axis or axes. @@ -3013,20 +3011,6 @@ class Tensor(OpMixin): assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}" return self.idiv(2 ** x, reverse) - def minimum(self, x:Tensor|ConstType) -> Tensor: - """ - Computes element-wise minimum of `self` and `x`. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1, 2, 3]).minimum(1).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy()) - ``` - """ - t, x = self._broadcasted(x) - return t._inverse().maximum(x._inverse())._inverse() - def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor: """ Returns a tensor of elements selected from either `x` or `y`, depending on `self`. @@ -3051,21 +3035,6 @@ class Tensor(OpMixin): out_shape = _broadcast_shape(self.shape, x.shape) return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape)) - def copysign(self, other) -> Tensor: - """ - Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise. - """ - # NOTE: torch always return in float, we return based on the broadcasting rule. 
- other = self._broadcasted(other)[1] - return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1) - - def logaddexp(self, other) -> Tensor: - """ - Calculates (self.exp()+other.exp()).log(), elementwise. - """ - m = self.maximum(other) - return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m - # ***** op wrappers ***** # TODO: combine with UOps __floordiv__ diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index 5c91e853af..776460c93d 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -10,9 +10,12 @@ if z3.get_version() < (4, 12, 4, 0): # IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0 z3.BoolRef: - assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}" - return a^b +def z3_xor(a:z3.ExprRef, b:z3.ExprRef) -> z3.ExprRef: + if isinstance(a, z3.BoolRef): return a^b + # x ^ -1 = -(x+1), i.e. bitwise NOT + if isinstance(b, z3.IntNumRef) and b.as_long() == -1: return -(a+1) + if isinstance(a, z3.IntNumRef) and a.as_long() == -1: return -(b+1) + raise RuntimeError(f"z3 int XOR only supports XOR with -1, got {a=} {b=}") z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a