mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move more broadcast method to mixin [pr] (#15513)
* move more broadcast method to mixin [pr] all but div, mod, and where * xor -1
This commit is contained in:
@@ -287,8 +287,36 @@ class ElementwiseMixin(DTypeMixin):
|
||||
"""
|
||||
return self._binop(Ops.MAX, x, False)
|
||||
|
||||
def _inverse(self) -> Self: return -self if self.is_floating_point() else ~self
|
||||
|
||||
def minimum(self, x: Self | ConstType) -> Self:
|
||||
return -(-self).maximum(-self.ufix(x))
|
||||
"""
|
||||
Computes element-wise minimum of `self` and `x`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).minimum(1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
|
||||
```
|
||||
"""
|
||||
t, x = self._broadcasted(x)
|
||||
return t._inverse().maximum(x._inverse())._inverse()
|
||||
|
||||
def copysign(self, other: Self | ConstType) -> Self:
|
||||
"""
|
||||
Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
|
||||
"""
|
||||
# NOTE: torch always return in float, we return based on the broadcasting rule.
|
||||
other = self._broadcasted(other)[1]
|
||||
return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1)
|
||||
|
||||
def logaddexp(self, other: Self | ConstType) -> Self:
|
||||
"""
|
||||
Calculates (self.exp()+other.exp()).log(), elementwise.
|
||||
"""
|
||||
m = self.maximum(other)
|
||||
return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
|
||||
|
||||
def where(self, x: Self | ConstType, y: Self | ConstType) -> Self:
|
||||
if isinstance(x, type(self)):
|
||||
|
||||
@@ -1699,8 +1699,6 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return self._reduce(Ops.MAX, axis, keepdim)
|
||||
|
||||
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self
|
||||
|
||||
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
|
||||
"""
|
||||
Returns the minimum value of the tensor along the specified axis or axes.
|
||||
@@ -3013,20 +3011,6 @@ class Tensor(OpMixin):
|
||||
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0 and not reverse, f"not supported {self.dtype=} {x=}"
|
||||
return self.idiv(2 ** x, reverse)
|
||||
|
||||
def minimum(self, x:Tensor|ConstType) -> Tensor:
|
||||
"""
|
||||
Computes element-wise minimum of `self` and `x`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).minimum(1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).minimum(Tensor([-4, -2, 9])).numpy())
|
||||
```
|
||||
"""
|
||||
t, x = self._broadcasted(x)
|
||||
return t._inverse().maximum(x._inverse())._inverse()
|
||||
|
||||
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
|
||||
"""
|
||||
Returns a tensor of elements selected from either `x` or `y`, depending on `self`.
|
||||
@@ -3051,21 +3035,6 @@ class Tensor(OpMixin):
|
||||
out_shape = _broadcast_shape(self.shape, x.shape)
|
||||
return self.cast(dtypes.bool)._broadcast_to(out_shape)._apply_uop(UOp.where, x._broadcast_to(out_shape), y._broadcast_to(out_shape))
|
||||
|
||||
def copysign(self, other) -> Tensor:
|
||||
"""
|
||||
Returns a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
|
||||
"""
|
||||
# NOTE: torch always return in float, we return based on the broadcasting rule.
|
||||
other = self._broadcasted(other)[1]
|
||||
return self.abs() * ((other < 0) | (other.reciprocal() < 0)).where(-1, 1)
|
||||
|
||||
def logaddexp(self, other) -> Tensor:
|
||||
"""
|
||||
Calculates (self.exp()+other.exp()).log(), elementwise.
|
||||
"""
|
||||
m = self.maximum(other)
|
||||
return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
|
||||
|
||||
# ***** op wrappers *****
|
||||
|
||||
# TODO: combine with UOps __floordiv__
|
||||
|
||||
@@ -10,9 +10,12 @@ if z3.get_version() < (4, 12, 4, 0):
|
||||
|
||||
# IDIV is truncated division but z3 does euclidian division (floor if b>0 ceil otherwise); mod by power of two sometimes uses Ops.AND
|
||||
def z3_cdiv(a:z3.ArithRef, b:z3.ArithRef) -> z3.ArithRef:return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
|
||||
def z3_xor(a:z3.BoolRef, b:z3.BoolRef) -> z3.BoolRef:
|
||||
assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}"
|
||||
return a^b
|
||||
def z3_xor(a:z3.ExprRef, b:z3.ExprRef) -> z3.ExprRef:
|
||||
if isinstance(a, z3.BoolRef): return a^b
|
||||
# x ^ -1 = -(x+1), i.e. bitwise NOT
|
||||
if isinstance(b, z3.IntNumRef) and b.as_long() == -1: return -(a+1)
|
||||
if isinstance(a, z3.IntNumRef) and a.as_long() == -1: return -(b+1)
|
||||
raise RuntimeError(f"z3 int XOR only supports XOR with -1, got {a=} {b=}")
|
||||
z3_alu: dict[Ops, Callable[..., z3.ExprRef]] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv,
|
||||
Ops.SHR: lambda a,b: a/(2**b.as_long()), Ops.SHL: lambda a,b: a*(2**b.as_long()),
|
||||
Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor, Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
|
||||
|
||||
Reference in New Issue
Block a user