mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
Tensor.copysign (#9329)
This commit is contained in:
@@ -263,8 +263,7 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
|
||||
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
|
||||
"aten.log1p.out": lambda self: (self+1).log(),
|
||||
"aten.expm1.out": lambda self: self.exp() - 1,
|
||||
# TODO: move to tinygrad
|
||||
"aten.copysign.out": lambda input,other: input.abs() * other.sign(),
|
||||
"aten.copysign.out": Tensor.copysign,
|
||||
# TODO: this gets the shape wrong
|
||||
#"aten.arange.start_out": Tensor.arange,
|
||||
"aten.lerp.Scalar_out": Tensor.lerp,
|
||||
|
||||
@@ -891,6 +891,16 @@ class TestOps(unittest.TestCase):
|
||||
def test_sign_exact(self):
|
||||
helper_test_op(None, torch.sign, Tensor.sign, vals=[[-1.,0,1]])
|
||||
|
||||
def test_copysign(self):
|
||||
helper_test_op([(45,65), (45,65)], torch.copysign, Tensor.copysign)
|
||||
helper_test_op([(45,65), (45,1)], torch.copysign, Tensor.copysign)
|
||||
helper_test_op([(45,1), (1,65)], torch.copysign, Tensor.copysign)
|
||||
helper_test_op([(), ()], torch.copysign, Tensor.copysign)
|
||||
def test_copysign_exact(self):
|
||||
for i in [-1.,0.,1.]:
|
||||
for j in [-1., 0., 1.]:
|
||||
helper_test_op(None, torch.copysign, Tensor.copysign, vals=[[i], [j]])
|
||||
|
||||
def test_softsign(self):
|
||||
helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign)
|
||||
helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign)
|
||||
|
||||
@@ -3460,6 +3460,15 @@ class Tensor(SimpleMathTrait):
|
||||
|
||||
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
|
||||
|
||||
def copysign(self, other) -> Tensor:
|
||||
"""
|
||||
Return a tensor of with the magnitude of `self` and the sign of `other`, elementwise.
|
||||
"""
|
||||
# NOTE: torch always return in float, we return based on the broadcasting rule.
|
||||
other = self._broadcasted(other)[1]
|
||||
# TODO: remove other*0?
|
||||
return (other < 0).where(-self.abs(), self.abs()) + other*0
|
||||
|
||||
# ***** op wrappers *****
|
||||
|
||||
def __invert__(self) -> Tensor: return self.bitwise_not()
|
||||
|
||||
Reference in New Issue
Block a user