Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
Tensor.logaddexp (#11793)
@@ -78,6 +78,7 @@ Elementwise ops operate on a per element basis. They don't change the shape of the tensor.
 ::: tinygrad.Tensor.minimum
 ::: tinygrad.Tensor.where
 ::: tinygrad.Tensor.copysign
+::: tinygrad.Tensor.logaddexp

 ## Casting Ops

@@ -381,6 +381,7 @@ decomps = [
   aten.elu, # elu has a scale + input_scale param
   aten.elu_backward,
   aten.softplus,
+  aten.logaddexp,
   aten.threshold,
   aten.nll_loss_forward,
   aten.nll_loss_backward,
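
Aside, not part of the commit: aten.logaddexp is the ATen op behind torch.logaddexp, which computes log(exp(a) + exp(b)) elementwise. A minimal standalone check of the expected values, assuming only stock PyTorch:

import torch

a = torch.tensor([-100.0, 1.0])
b = torch.tensor([-1.0, 2.0])
# log(exp(a) + exp(b)), computed stably; roughly [-1.0000, 2.3133]
print(torch.logaddexp(a, b))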
@@ -928,6 +928,12 @@ class TestOps(unittest.TestCase):
       for j in [-1., 0., 1.]:
         helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]])

+  def test_logaddexp(self):
+    helper_test_op([(45,65), (45,65)], torch.logaddexp, Tensor.logaddexp)
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-1.], [-1.0, 2, 3]])
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[-100.0, -200, -300], [-1.0, 2, 3]])
+    helper_test_op(None, torch.logaddexp, Tensor.logaddexp, vals=[[1.0, 2000, 30000], [-1.0, 2, 3]])
+
   def test_softsign(self):
     helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign)
     helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign)
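
Aside on the extreme-valued test cases above: inputs like [-100.0, -200, -300] and [1.0, 2000, 30000] probe the regime where a naive (a.exp() + b.exp()).log() underflows to -inf or overflows to inf in float32, while a stable logaddexp stays finite. A minimal sketch of that comparison using numpy rather than the repo's helper_test_op (numpy emits RuntimeWarnings on the naive path):

import numpy as np

a = np.array([-200.0, 2000.0], dtype=np.float32)
b = np.array([-300.0, 3.0], dtype=np.float32)
naive = np.log(np.exp(a) + np.exp(b))  # exp underflows/overflows in float32: roughly [-inf, inf]
stable = np.logaddexp(a, b)            # numpy's stable reference: roughly [-200., 2000.]
print(naive, stable)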
@@ -3759,6 +3759,13 @@ class Tensor(MathTrait):
     # TODO: remove other*0?
     return (other < 0).where(-self.abs(), self.abs()) + other*0

+  def logaddexp(self, other) -> Tensor:
+    """
+    Calculates (self.exp()+other.exp()).log(), elementwise.
+    """
+    m = self.maximum(other)
+    return ((self-m).exp() + (self._broadcasted(other)[1]-m).exp()).log() + m
+
   # ***** op wrappers *****

   def __invert__(self) -> Tensor: return self.bitwise_not()
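
Aside on the implementation above: it is the standard log-sum-exp rewrite. With m = max(a, b), log(exp(a) + exp(b)) = m + log(exp(a - m) + exp(b - m)), so both exp() arguments are <= 0 and cannot overflow; the _broadcasted(other)[1] call casts and broadcasts other against self before m is subtracted. A minimal sketch of the same identity with plain Python floats, independent of tinygrad (logaddexp_stable is just an illustrative name):

import math

def logaddexp_stable(a: float, b: float) -> float:
  # subtract the max so both exp() arguments are <= 0
  m = max(a, b)
  return m + math.log(math.exp(a - m) + math.exp(b - m))

print(logaddexp_stable(1000.0, 2.0))  # ~1000.0; math.exp(1000) alone would raise OverflowError
print(logaddexp_stable(-1.0, -1.0))   # -1 + log(2), roughly -0.3069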