From b110c4a7b8d2db09eb2d169d4dc294c55cf55584 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 8 Feb 2024 04:11:45 -0500
Subject: [PATCH] explicitly set input low and high in test_ops (#3347)

easier to set `(low, high)` than figuring out a,b for `(x+a)*b`.
this pr kept the same input ranges
---
 test/test_ops.py | 73 ++++++++++++++++++++++++++----------------------
 1 file changed, 40 insertions(+), 33 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 270cf7690a..81143dec4f 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -10,10 +10,11 @@ if CI:
 
 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 PRINT_TENSORS = getenv("PRINT_TENSORS", 0)
+
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
-                   forward_only=False, vals=None, a=-0.5, b=3):
+                   forward_only=False, vals=None, low=-1.5, high=1.5):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
-  ts, tst = prepare_test_op(a, b, shps, vals, forward_only)
+  ts, tst = prepare_test_op(low, high, shps, vals, forward_only)
 
   st = time.monotonic()
   out = torch_fxn(*ts)
@@ -59,19 +60,19 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
 
     print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
       (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
 
-def prepare_test_op(a, b, shps, vals, forward_only=False):
+def prepare_test_op(low, high, shps, vals, forward_only=False):
   torch.manual_seed(0)
   np.random.seed(0)
   if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
-  else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=(not forward_only), dtype=torch.float32) for x in shps]
+  else: ts = [torch.tensor(np.random.uniform(low=low, high=high, size=x), requires_grad=(not forward_only), dtype=torch.float32) for x in shps]
   tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst
 
 class TestOps(unittest.TestCase):
-  def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, a=-0.5, b=3):
+  def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, low=-1.5, high=1.5):
     if getenv("CUDACPU"): self.skipTest('helper_test_exception fails in CUDACPU')
-    ts, tst = prepare_test_op(a, b, shps, vals)
+    ts, tst = prepare_test_op(low, high, shps, vals)
     with self.assertRaises(expected) as torch_cm:
       torch_fxn(*ts)
     with self.assertRaises(expected) as tinygrad_cm:
@@ -370,13 +371,13 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: x**2)
     helper_test_op([()], lambda x: x**-2)
     # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151
-    helper_test_op([(45,65)], lambda x: x**3, a=-10)
-    helper_test_op([()], lambda x: x**3, a=-10)
+    helper_test_op([(45,65)], lambda x: x**3, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**3, low=-30, high=-27)
     # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251
-    helper_test_op([(45,65)], lambda x: x**0.2, a=-10)
-    helper_test_op([(45,65)], lambda x: x**1.2, a=-10)
-    helper_test_op([()], lambda x: x**0.2, a=-10)
-    helper_test_op([()], lambda x: x**1.2, a=-10)
+    helper_test_op([(45,65)], lambda x: x**0.2, low=-30, high=-27)
+    helper_test_op([(45,65)], lambda x: x**1.2, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**0.2, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**1.2, low=-30, high=-27)
     a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True)
     helper_test_op([], lambda: b**1.1, lambda: a**1.1)
   def test_pow_const(self):
@@ -458,8 +459,8 @@ class TestOps(unittest.TestCase):
   def test_sigmoid(self):
     helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
-    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=100)
-    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=-100)
+    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=303)
+    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-300, high=-297)
     helper_test_op([()], torch.sigmoid, Tensor.sigmoid)
 
   def test_softplus(self):
     helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6)
@@ -468,12 +469,12 @@ class TestOps(unittest.TestCase):
   def test_gelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
     if not (CI and Device.DEFAULT == "METAL"):
-      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100)
-      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100)
+      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=303)
+      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=-300, high=-297)
   def test_quick_gelu(self):
     helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
-    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=100)
-    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100)
+    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=303)
+    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-300, high=-297)
     helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
 
   def test_elu(self):
@@ -750,17 +751,17 @@ class TestOps(unittest.TestCase):
   def test_sinh(self):
     helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6)
     # TODO: backward nan instead of inf
-    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
-    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, low=300, high=303, forward_only=True)
   def test_cosh(self):
     helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6)
     # TODO: backward nan instead of inf
-    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
-    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, low=300, high=303, forward_only=True)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_hardtanh(self):
     for val in range(10, 30, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
@@ -768,16 +769,16 @@ class TestOps(unittest.TestCase):
   def test_asinh(self):
     helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6)
     # NOTE: this one has larger atol
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_acosh(self):
     helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_atanh(self):
     helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
 
   def test_topo_sort(self):
     helper_test_op([(45,65)], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
@@ -792,7 +793,10 @@ class TestOps(unittest.TestCase):
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
       for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
         with self.subTest(op=torch_op.__name__, shapes=shapes):
-          helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
+          if tinygrad_op != Tensor.pow:
+            helper_test_op(shapes, torch_op, tinygrad_op)
+          else:
+            helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)
 
   def test_broadcast_simple(self):
     helper_test_op([(45,65), (45,1)], lambda x,y: x/y)
@@ -800,12 +804,15 @@ class TestOps(unittest.TestCase):
 
   def test_broadcast_partial(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
-                                  (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]:
+                                  (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
       for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)), ((4,1), (4,5)), ((1,4), (5,4))]:
         with self.subTest(op=torch_op.__name__, shapes=shapes):
           # NOTE: ANE backwards?
-          helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
+          if tinygrad_op != Tensor.pow:
+            helper_test_op(shapes, torch_op, tinygrad_op)
+          else:
+            helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)
 
   def test_slice_in_bounds_1dim(self):
     helper_test_op([(3)], lambda x: x[1:3])
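
Note on the range mapping (editorial, not part of the patch): the old sampler `(np.random.random(size=x) + a) * b` draws from [0, 1), shifts by `a`, and scales by `b`, so the old inputs were uniform on `[a*b, (a+1)*b)`. The patch states those bounds directly as `low = a*b`, `high = (a+1)*b`: the defaults `a=-0.5, b=3` become `low=-1.5, high=1.5`, and `a=-10, b=3` becomes `low=-30, high=-27`, which is why every replacement above keeps the same input range. A minimal sketch of the equivalence (illustrative only; `old_inputs` and `new_inputs` are hypothetical helpers, not part of test_ops.py):

import numpy as np

def old_inputs(shape, a=-0.5, b=3):
  # old scheme: shift-and-scale a [0, 1) uniform sample -> [a*b, (a+1)*b)
  return (np.random.random(size=shape) + a) * b

def new_inputs(shape, low=-1.5, high=1.5):
  # new scheme: state the bounds directly
  return np.random.uniform(low=low, high=high, size=shape)

# the (a, b) pairs used in the old tests map onto the (low, high) pairs in the diff
for a, b in [(-0.5, 3), (-10, 3), (100, 3), (0.0, 3)]:
  low, high = a * b, (a + 1) * b
  samples = old_inputs((45, 65), a, b)
  assert samples.min() >= low and samples.max() < high, (a, b, low, high)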