update Tensor.maximum (#8992)

now it's just broadcast and UOp.maximum
This commit is contained in:
chenyu
2025-02-09 21:26:27 -05:00
committed by GitHub
parent 88add71c25
commit 9119716761

View File

@@ -3307,9 +3307,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
```
"""
# NOTE: the mid-point is for backward, revisit after new gradient API
if self.is_floating_point(): return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self))
return (self<x).detach().where(x, self)
return self._apply_broadcasted_uop(UOp.maximum, x)
def minimum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""