Tensor.logaddexp (#11793)

Author: chenyu
Date: 2025-08-23 09:15:00 -04:00
Committed by: GitHub
Parent: 5a6817d5f8
Commit: fb8ee02424
4 changed files with 15 additions and 0 deletions


@@ -78,6 +78,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of the tensor.
 ::: tinygrad.Tensor.minimum
 ::: tinygrad.Tensor.where
 ::: tinygrad.Tensor.copysign
+::: tinygrad.Tensor.logaddexp
 ## Casting Ops


@@ -381,6 +381,7 @@ decomps = [
   aten.elu, # elu has a scale + input_scale param
   aten.elu_backward,
   aten.softplus,
+  aten.logaddexp,
   aten.threshold,
   aten.nll_loss_forward,
   aten.nll_loss_backward,

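For context, `decomps` collects aten ops that the torch frontend lowers through torch's decomposition table instead of mapping them one-to-one onto tinygrad ops; adding `aten.logaddexp` works because torch ships a reference decomposition for it. A hedged sketch of fetching that decomposition with torch's public helper (exactly how the surrounding backend consumes this list is an assumption here):

import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
# Ask torch for its registered decomposition(s) of logaddexp; a backend can
# execute these instead of implementing the op natively.
table = get_decompositions([aten.logaddexp])
print(table)  # maps the logaddexp overload(s) to python decomposition functions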

@@ -928,6 +928,12 @@ class TestOps(unittest.TestCase):
       for j in [-1., 0., 1.]:
         helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]])
+  def test_logaddexp(self):
+    helper_test_op([(45,65), (45,65)], torch.logaddexp, Tensor.logaddexp)
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-1.], [-1.0, 2, 3]])
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-100.0, -200, -300], [-1.0, 2, 3]])
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[1.0, 2000, 30000], [-1.0, 2, 3]])
   def test_softsign(self):
     helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign)
     helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign)

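The last `vals` case is the one a naive `(x.exp() + y.exp()).log()` translation fails: `exp(2000)` already overflows float32 to `inf`. A minimal standalone sketch (plain PyTorch, separate from the test suite) of what that case probes:

import torch

a = torch.tensor([1.0, 2000.0, 30000.0])
b = torch.tensor([-1.0, 2.0, 3.0])

naive = (a.exp() + b.exp()).log()  # exp(2000) overflows to inf, so log() returns inf
stable = torch.logaddexp(a, b)     # stays finite for every pair

print(naive)   # tensor([1.1269, inf, inf])
print(stable)  # tensor([1.1269e+00, 2.0000e+03, 3.0000e+04])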

@@ -3759,6 +3759,13 @@ class Tensor(MathTrait):
     # TODO: remove other*0?
     return (other < 0).where(-self.abs(), self.abs()) + other*0
+  def logaddexp(self, other) -> Tensor:
+    """
+    Calculates `(self.exp() + other.exp()).log()` elementwise, in a numerically stable way.
+    """
+    # subtract the elementwise max so neither exp() sees a positive argument;
+    # _broadcasted converts a scalar or list `other` into a broadcast Tensor
+    m = self.maximum(other)
+    return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
 # ***** op wrappers *****
   def __invert__(self) -> Tensor: return self.bitwise_not()
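
The implementation relies on the standard max-subtraction identity: with m = max(x, y), log(exp(x) + exp(y)) = m + log(exp(x - m) + exp(y - m)). Both exp() arguments are then <= 0, so neither can overflow, and the dominant term is carried out of the log as m. A quick usage sketch (illustrative values, approximate outputs):

from tinygrad import Tensor

a = Tensor([1.0, 2000.0])
b = Tensor([-1.0, 3.0])

print(a.logaddexp(b).numpy())             # ~[1.1269, 2000.0]
print((a.exp() + b.exp()).log().numpy())  # naive form: exp(2000.0) overflows to inf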