move more broadcast method to mixin [pr] (#15513)

* move more broadcast method to mixin [pr]

all but div, mod, and where

* xor -1
This commit is contained in:
chenyu
2026-03-28 01:48:08 -04:00
committed by GitHub
parent c0753ab62f
commit fe705def0d
3 changed files with 35 additions and 35 deletions

View File

@@ -287,8 +287,36 @@ class ElementwiseMixin(DTypeMixin):
"""
return self._binop(Ops.MAX, x, False)
def _inverse(self) -> Self: return -self if self.is_floating_point() else ~self
def minimum(self, x: Self | ConstType) -> Self:
return -(-self).maximum(-self.ufix(x))
"""
Computes element-wise minimum of `self` and `x`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).minimum(1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
```
"""
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()
def copysign(self, other: Self | ConstType) -> Self:
"""
Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
"""
# NOTE: torch always return in float, we return based on the broadcasting rule.
other = self._broadcasted(other)[1]
return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1)
def logaddexp(self, other: Self | ConstType) -> Self:
"""
Calculates (self.exp()+other.exp()).log(), elementwise.
"""
m = self.maximum(other)
return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
def where(self, x: Self | ConstType, y: Self | ConstType) -> Self:
if isinstance(x, type(self)):

View File

@@ -1699,8 +1699,6 @@ class Tensor(OpMixin):
"""
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the minimum value of the tensor along the specified axis or axes.
@@ -3013,20 +3011,6 @@ class Tensor(OpMixin):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
return self.idiv(2 ** x, reverse)
def minimum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise minimum of `self` and `x`.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).minimum(1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
```
"""
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
@@ -3051,21 +3035,6 @@ class Tensor(OpMixin):
out_shape = _broadcast_shape(self.shape, x.shape)
return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape))
def copysign(self, other) -> Tensor:
"""
Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
"""
# NOTE: torch always return in float, we return based on the broadcasting rule.
other = self._broadcasted(other)[1]
return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1)
def logaddexp(self, other) -> Tensor:
"""
Calculates (self.exp()+other.exp()).log(), elementwise.
"""
m = self.maximum(other)
return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
# ***** op wrappers *****
# TODO: combine with UOps __floordiv__

View File

@@ -10,9 +10,12 @@ if z3.get_version() < (4, 12, 4, 0):
# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_xor(a:z3.BoolRef, b:z3.BoolRef) -> z3.BoolRef:
assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}"
return a^b
def z3_xor(a:z3.ExprRef, b:z3.ExprRef) -> z3.ExprRef:
if isinstance(a, z3.BoolRef): return a^b
# x ^ -1 = -(x+1), i.e. bitwise NOT
if isinstance(b, z3.IntNumRef) and b.as_long() == -1: return -(a+1)
if isinstance(a, z3.IntNumRef) and a.as_long() == -1: return -(b+1)
raise RuntimeError(f"z3 int XOR only supports XOR with -1, got {a=} {b=}")
z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv,
Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()),
Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a<b, b, a),}