explicitly set input low and high in test_ops (#3347)

It's easier to set `(low, high)` directly than to figure out `a` and `b` for `(x+a)*b`. This PR keeps the same input ranges.
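For reference, the old sampler `(np.random.random(size=x) + a) * b` draws uniformly from `[a*b, (a+1)*b)` when `b > 0`, so every old `(a, b)` pair translates directly into the new `(low, high)` arguments (e.g. the default `a=-0.5, b=3` becomes `low=-1.5, high=1.5`, and `a=100, b=3` becomes `low=300, high=303`). A minimal sketch of that mapping, assuming only NumPy; `old_range` is a hypothetical helper for illustration and is not part of the diff:

```python
# Sketch (not part of the diff): how each old (a, b) pair maps onto the new (low, high)
# arguments. np.random.random() draws uniformly from [0, 1), so the old expression
# (np.random.random(size=x) + a) * b is uniform on [a*b, (a+1)*b) for b > 0.
import numpy as np

def old_range(a, b=3):
  # interval covered by the old sampling expression (hypothetical helper, illustration only)
  return (a * b, (a + 1) * b)

# (a, b) values used in test_ops before this PR and their (low, high) equivalents
for a in (-0.5, -10, 100, -100, 0.0):
  low, high = old_range(a)
  print(f"a={a:>6}, b=3  ->  low={low}, high={high}")

# sanity check: old and new samplers cover the same interval (the RNG streams differ)
old = (np.random.random(size=(45, 65)) + -0.5) * 3          # old default: uniform on [-1.5, 1.5)
new = np.random.uniform(low=-1.5, high=1.5, size=(45, 65))  # new default: same interval
assert old.min() >= -1.5 and old.max() < 1.5
assert new.min() >= -1.5 and new.max() < 1.5
```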
Author: chenyu
Date: 2024-02-08 04:11:45 -05:00
Committed by: GitHub
Parent: d8ad9e5660
Commit: b110c4a7b8


@@ -10,10 +10,11 @@ if CI:
 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
 PRINT_TENSORS = getenv("PRINT_TENSORS", 0)
 def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
-                   forward_only=False, vals=None, a=-0.5, b=3):
+                   forward_only=False, vals=None, low=-1.5, high=1.5):
   if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
-  ts, tst = prepare_test_op(a, b, shps, vals, forward_only)
+  ts, tst = prepare_test_op(low, high, shps, vals, forward_only)
   st = time.monotonic()
   out = torch_fxn(*ts)
@@ -59,19 +60,19 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \
     (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="")
-def prepare_test_op(a, b, shps, vals, forward_only=False):
+def prepare_test_op(low, high, shps, vals, forward_only=False):
   torch.manual_seed(0)
   np.random.seed(0)
   if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals]
-  else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=(not forward_only), dtype=torch.float32) for x in shps]
+  else: ts = [torch.tensor(np.random.uniform(low=low, high=high, size=x), requires_grad=(not forward_only), dtype=torch.float32) for x in shps]
   tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
   return ts, tst
 class TestOps(unittest.TestCase):
-  def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, a=-0.5, b=3):
+  def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, low=-1.5, high=1.5):
     if getenv("CUDACPU"): self.skipTest('helper_test_exception fails in CUDACPU')
-    ts, tst = prepare_test_op(a, b, shps, vals)
+    ts, tst = prepare_test_op(low, high, shps, vals)
     with self.assertRaises(expected) as torch_cm:
       torch_fxn(*ts)
     with self.assertRaises(expected) as tinygrad_cm:
@@ -370,13 +371,13 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: x**2)
     helper_test_op([()], lambda x: x**-2)
     # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151
-    helper_test_op([(45,65)], lambda x: x**3, a=-10)
-    helper_test_op([()], lambda x: x**3, a=-10)
+    helper_test_op([(45,65)], lambda x: x**3, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**3, low=-30, high=-27)
     # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251
-    helper_test_op([(45,65)], lambda x: x**0.2, a=-10)
-    helper_test_op([(45,65)], lambda x: x**1.2, a=-10)
-    helper_test_op([()], lambda x: x**0.2, a=-10)
-    helper_test_op([()], lambda x: x**1.2, a=-10)
+    helper_test_op([(45,65)], lambda x: x**0.2, low=-30, high=-27)
+    helper_test_op([(45,65)], lambda x: x**1.2, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**0.2, low=-30, high=-27)
+    helper_test_op([()], lambda x: x**1.2, low=-30, high=-27)
     a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True)
     helper_test_op([], lambda: b**1.1, lambda: a**1.1)
   def test_pow_const(self):
@@ -458,8 +459,8 @@ class TestOps(unittest.TestCase):
   def test_sigmoid(self):
     helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
-    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=100)
-    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, a=-100)
+    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=303)
+    helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-300, high=-297)
     helper_test_op([()], torch.sigmoid, Tensor.sigmoid)
   def test_softplus(self):
     helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, atol=1e-6, grad_atol=1e-6)
@@ -468,12 +469,12 @@ class TestOps(unittest.TestCase):
   def test_gelu(self):
     helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
     if not (CI and Device.DEFAULT == "METAL"):
-      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100)
-      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100)
+      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=303)
+      helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=-300, high=-297)
   def test_quick_gelu(self):
     helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
-    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=100)
-    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100)
+    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=303)
+    helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-300, high=-297)
     helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
   def test_elu(self):
@@ -750,17 +751,17 @@ class TestOps(unittest.TestCase):
   def test_sinh(self):
     helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6)
     # TODO: backward nan instead of inf
-    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
-    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.sinh(), atol=1e-6, grad_atol=1e-6, low=300, high=303, forward_only=True)
   def test_cosh(self):
     helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6)
     # TODO: backward nan instead of inf
-    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=-100, forward_only=True)
-    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, a=100, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297, forward_only=True)
+    helper_test_op([(45,65)], lambda x: x.cosh(), atol=1e-6, grad_atol=1e-6, low=300, high=303, forward_only=True)
   def test_tanh(self):
     helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.tanh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_hardtanh(self):
     for val in range(10, 30, 5):
       helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6)
@@ -768,16 +769,16 @@ class TestOps(unittest.TestCase):
   def test_asinh(self):
     helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6)
     # NOTE: this one has larger atol
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_acosh(self):
     helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.acosh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_atanh(self):
     helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6)
-    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=-100)
-    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, a=100)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, low=-300, high=-297)
+    helper_test_op([(45,65)], lambda x: x.atanh(), atol=1e-6, grad_atol=1e-6, low=300, high=303)
   def test_topo_sort(self):
     helper_test_op([(45,65)], lambda x: (x+x)*x, atol=1e-6, grad_atol=1e-6)
@@ -792,7 +793,10 @@ class TestOps(unittest.TestCase):
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
       for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]:
         with self.subTest(op=torch_op.__name__, shapes=shapes):
-          helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
+          if tinygrad_op != Tensor.pow:
+            helper_test_op(shapes, torch_op, tinygrad_op)
+          else:
+            helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)
   def test_broadcast_simple(self):
     helper_test_op([(45,65), (45,1)], lambda x,y: x/y)
@@ -800,12 +804,15 @@ class TestOps(unittest.TestCase):
   def test_broadcast_partial(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
-                                  (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]:
+                                  (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
       for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)),
                      ((4,1), (4,5)), ((1,4), (5,4))]:
         with self.subTest(op=torch_op.__name__, shapes=shapes):
           # NOTE: ANE backwards?
-          helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0)
+          if tinygrad_op != Tensor.pow:
+            helper_test_op(shapes, torch_op, tinygrad_op)
+          else:
+            helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3)
   def test_slice_in_bounds_1dim(self):
     helper_test_op([(3)], lambda x: x[1:3])