fix copysign for -0 (#14380)

test both x and 1/x < 0 work too. and found another big with the * 0 hack
This commit is contained in:
chenyu
2026-01-27 15:44:58 -05:00
committed by GitHub
parent 62884585a7
commit 8c899e4aaf
2 changed files with 4 additions and 8 deletions

View File

@@ -948,14 +948,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,1), (1,65)], torch.copysign, Tensor.copysign)
helper_test_op([(), ()], torch.copysign, Tensor.copysign)
def test_copysign_exact(self):
for i in [-1.,0.,1.]:
for j in [-1., 0., 1.]:
v = [-1., -0., 0., 1.]#, math.inf, -math.inf] # TODO: * 0 hack does not work with inf
for i in v:
for j in v:
helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]])
# TODO: fix copysign to distinguish between -0.0 and 0.0
@unittest.skipIf(COMPILE_ONLY or getenv("TINY_BACKEND"), "test requires runtime")
@unittest.expectedFailure
def test_copysign_signed_zero(self):
helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[1.0, 1.0], [-0.0, 0.0]])
def test_logaddexp(self):
helper_test_op([(45,65), (45,65)], torch.logaddexp, Tensor.logaddexp)

View File

@@ -3375,7 +3375,7 @@ class Tensor(OpMixin):
# NOTE: torch always return in float, we return based on the broadcasting rule.
other = self._broadcasted(other)[1]
# TODO: remove other*0?
return (other < 0).where(-self.abs(), self.abs()) + other*0
return ((other < 0) | (other.reciprocal() < 0)).where(-self.abs(), self.abs()) + other*0
def logaddexp(self, other) -> Tensor:
"""