mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
test_ops sinh/cosh/asinh/acosh/atanh (#3294)
some have numerical issues at large inputs, similar to sigmoid
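The float32 behaviour behind that caveat is easy to reproduce with plain numpy. The sketch below is illustrative only: it is not code from the commit, and it assumes nothing about tinygrad's lowering beyond ordinary float32 arithmetic.

import numpy as np

# Sketch: why |x| = 100 is a stress case for sinh/cosh in float32.
x = np.float32(100.0)
print(np.exp(x))                               # exp(100) ~ 2.7e43 overflows float32 -> inf
print(np.float32(np.sinh(np.float64(100.0))))  # even the exact sinh(100) ~ 1.3e43 is inf in float32
print(np.float32(np.inf) - np.float32(np.inf)) # inf - inf = nan, the kind of value a backward pass can hit

That is why the large-input sinh/cosh cases below are forward_only: the forward values still agree with torch, but per the TODO comments in the diff the backward pass currently yields nan where torch gives inf.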
@@ -23,13 +23,13 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   ret = tinygrad_fxn(*tst).realize()
   tinygrad_fp = time.monotonic() - st
 
-  def compare(s, x,y,atol,rtol):
-    if PRINT_TENSORS: print(s, x, y)
-    assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}"
+  def compare(s, tinygrad_output, torch_output, atol, rtol):
+    if PRINT_TENSORS: print(s, tinygrad_output, torch_output)
+    assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
     try:
-      np.testing.assert_allclose(x,y, atol=atol, rtol=rtol)
-    except Exception:
-      raise Exception(f"{s} failed shape {x.shape}")
+      np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
+    except Exception as e:
+      raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")
 
   if DEBUG >= 6:
     np.set_printoptions(linewidth=200, suppress=True)
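For reference, the atol/rtol arguments threaded through compare() keep numpy's usual semantics (standard np.testing behaviour, not something this commit changes): assert_allclose passes when abs(actual - desired) <= atol + rtol * abs(desired) holds elementwise, so atol governs values near zero and rtol governs large values. A minimal illustration:

import numpy as np

# Tolerance rule used by compare() above: atol + rtol * abs(desired).
actual  = np.array([1.0000005, 100.001])
desired = np.array([1.0,       100.0])
# Both elements pass: 5e-7 <= 1e-6 + 1e-3*1 and 1e-3 <= 1e-6 + 1e-3*100.
np.testing.assert_allclose(actual, desired, atol=1e-6, rtol=1e-3)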
@@ -707,14 +707,38 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7)
 
+  def test_sinh(self):
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6)
+    # TODO: backward nan instead of inf
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
+  def test_cosh(self):
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6)
+    # TODO: backward nan instead of inf
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
     helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=-100)
     helper_test_op([()], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
     helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=100)
   def test_hardtanh(self):
     for val in range(10, 30, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
       helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
+  def test_asinh(self):
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6)
+    # NOTE: this one has larger atol
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, a=-100)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6, a=100)
+  def test_acosh(self):
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=-100)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=100)
+  def test_atanh(self):
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=-100)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=100)
 
   def test_topo_sort(self):
     helper_test_op([(45,65)], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
     helper_test_op([()], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
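Two notes on the new tests above. The larger atol on the asinh a=-100 case is a cancellation effect: with the textbook lowering log(x + sqrt(x*x + 1)) (an assumption here, since the lowering is not part of this diff), a large negative x makes the sum subtract two nearly equal values of magnitude about |x|, so a few float32 ulps of rounding in the sqrt turn into an absolute error on the order of 1e-3 in the result. A hypothetical numpy reproduction:

import numpy as np

# Naive float32 asinh at x = -100 vs a well-conditioned float64 reference.
x = np.float32(-100.0)
naive = np.log(x + np.sqrt(x * x + np.float32(1.0)))  # cancellation in x + sqrt(x*x + 1)
reference = np.arcsinh(np.float64(-100.0))
print(float(naive), reference, abs(float(naive) - reference))  # error ~5e-4: above 1e-6, within 1e-2

The acosh a=-100 case is different: inputs below 1 are outside acosh's domain, so both torch and tinygrad return nan there, and np.testing.assert_allclose treats matching nans as equal (equal_nan defaults to True), so the usual atol=1e-6 still works for that case.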