Add support for sign; technically, relu can be second-class now.

This commit is contained in:
George Hotz
2021-01-03 08:29:57 -08:00
parent 6842ad9ec8
commit c2eeb6950b
5 changed files with 34 additions and 4 deletions

View File

@@ -62,10 +62,16 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log)
def test_exp(self):
    """Check that Tensor.exp matches torch.exp on a random (45, 65) input."""
    helper_test_op([(45, 65)], torch.exp, Tensor.exp)
def test_sign(self):
    """Check that Tensor.sign matches torch.sign on a random (45, 65) input."""
    helper_test_op([(45, 65)], torch.sign, Tensor.sign)
def test_sigmoid(self):
    """Check that Tensor.sigmoid matches torch's sigmoid on a random (45, 65) input."""
    helper_test_op([(45, 65)], lambda t: t.sigmoid(), Tensor.sigmoid)
def test_softplus(self):
    """Check that Tensor.softplus matches torch's softplus on a random (45, 65) input.

    Uses loosened tolerances (1e-6) for both forward and gradient comparison.
    """
    torch_fn = torch.nn.functional.softplus
    helper_test_op([(45, 65)], lambda t: torch_fn(t), Tensor.softplus, atol=1e-6, grad_atol=1e-6)
def test_relu6(self):
    """Check that Tensor.relu6 matches torch's relu6 on a random (45, 65) input."""
    helper_test_op([(45, 65)], torch.nn.functional.relu6, Tensor.relu6)
def test_hardswish(self):
    """Check that Tensor.hardswish matches torch's hardswish on a random (45, 65) input.

    Uses loosened tolerances (1e-6) for both forward and gradient comparison.
    """
    torch_fn = torch.nn.functional.hardswish
    helper_test_op([(45, 65)], lambda t: torch_fn(t), Tensor.hardswish, atol=1e-6, grad_atol=1e-6)
def test_mish(self):
def _mish_pytorch(x):
return x*torch.tanh(torch.nn.functional.softplus(x))