mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix copysign for -0 (#14380)
test both x and 1/x < 0 work too. and found another big with the * 0 hack
This commit is contained in:
@@ -948,14 +948,10 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,1), (1,65)], torch.copysign, Tensor.copysign)
|
||||
helper_test_op([(), ()], torch.copysign, Tensor.copysign)
|
||||
def test_copysign_exact(self):
|
||||
for i in [-1.,0.,1.]:
|
||||
for j in [-1., 0., 1.]:
|
||||
v = [-1., -0., 0., 1.]#, math.inf, -math.inf] # TODO: * 0 hack does not work with inf
|
||||
for i in v:
|
||||
for j in v:
|
||||
helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]])
|
||||
# TODO: fix copysign to distinguish between -0.0 and 0.0
|
||||
@unittest.skipIf(COMPILE_ONLY or getenv("TINY_BACKEND"), "test requires runtime")
|
||||
@unittest.expectedFailure
|
||||
def test_copysign_signed_zero(self):
|
||||
helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[1.0, 1.0], [-0.0, 0.0]])
|
||||
|
||||
def test_logaddexp(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.logaddexp, Tensor.logaddexp)
|
||||
|
||||
@@ -3375,7 +3375,7 @@ class Tensor(OpMixin):
|
||||
# NOTE: torch always return in float, we return based on the broadcasting rule.
|
||||
other = self._broadcasted(other)[1]
|
||||
# TODO: remove other*0?
|
||||
return (other < 0).where(-self.abs(), self.abs()) + other*0
|
||||
return ((other < 0) | (other.reciprocal() < 0)).where(-self.abs(), self.abs()) + other*0
|
||||
|
||||
def logaddexp(self, other) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user