From 287de4ecc63b5d3dddaf27760a2455f215086c05 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 20 Feb 2025 12:26:11 -0500
Subject: [PATCH] use torch in test_gradient (#9186)

used torch.autograd.grad, but not sure if it can be a template like jax
---
 setup.py                   |  1 -
 test/unit/test_gradient.py | 37 +++++++++++++++++++++----------------
 2 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/setup.py b/setup.py
index ef5332021f..f0bd40a0de 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,6 @@ setup(name='tinygrad',
         'testing': [
           "numpy",
           "torch",
-          "jax",
           "pillow",
           "pytest",
           "pytest-xdist",
diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py
index 76c58559de..777703a9a1 100644
--- a/test/unit/test_gradient.py
+++ b/test/unit/test_gradient.py
@@ -1,7 +1,6 @@
 from typing import Callable
 import unittest, math
-import jax
-import jax.numpy as jnp
+import torch
 from tinygrad import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.ops import UOp, Ops
@@ -13,20 +12,22 @@ class TestGradient(unittest.TestCase):
     self.assertAlmostEqual(x, y, places=5)

   def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
+    if jf is None: jf = f
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
     gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
-    gf = jax.grad(f if jf is None else jf)

     for val in [-5., -2.0, 0.0, 2.0, 5.]:
-      tg_out, jax_out = gx.substitute({x: x.const_like(val)}).ssimplify(), gf(val).item()
-      self._cmp_nan_okay(tg_out, jax_out)
+      tg_out = gx.substitute({x: x.const_like(val)}).ssimplify()
+      tx = torch.tensor([val], dtype=torch.float, requires_grad=True)
+      torch_out = torch.autograd.grad(jf(tx), tx)[0].item()
+      self._cmp_nan_okay(tg_out, torch_out)

   def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
+    if jf is None: jf = f
     x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
     y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
     grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
     gx, gy = grads[x], grads[y]
-    gf = jax.grad(f if jf is None else jf, argnums=(0, 1))

     for valx in [-5., -2.0, 0.0, 2.0, 5.]:
       for valy in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -34,28 +35,32 @@ class TestGradient(unittest.TestCase):
         substitutions = {x: x.const_like(valx), y: y.const_like(valy)}
         tg_out_x = gx.substitute(substitutions).ssimplify()
         tg_out_y = gy.substitute(substitutions).ssimplify()
-        jax_out_x, jax_out_y = [x.item() for x in gf(valx, valy)]
-        self._cmp_nan_okay(tg_out_x, jax_out_x)
-        self._cmp_nan_okay(tg_out_y, jax_out_y)
+        tx = torch.tensor([valx], dtype=torch.float, requires_grad=True)
+        ty = torch.tensor([valy], dtype=torch.float, requires_grad=True)
+        torch_grad = torch.autograd.grad(jf(tx, ty), [tx, ty])
+        torch_out_x, torch_out_y = [x.item() for x in torch_grad]
+
+        self._cmp_nan_okay(tg_out_x, torch_out_x)
+        self._cmp_nan_okay(tg_out_y, torch_out_y)

   # unary ops unit
   def test_recip(self): self._test_one_input_function(lambda x: 1.0/x)
-  def test_sin(self): self._test_one_input_function(lambda x: x.sin(), lambda x: jnp.sin(x))
-  def test_sqrt(self): self._test_one_input_function(lambda x: x.sqrt(), lambda x: jnp.sqrt(x))
-  def test_log2(self): self._test_one_input_function(lambda x: x.log2(), lambda x: jnp.log2(x))
-  def test_exp2(self): self._test_one_input_function(lambda x: x.exp2(), lambda x: jnp.exp2(x))
+  def test_sin(self): self._test_one_input_function(lambda x: x.sin())
+  def test_sqrt(self): self._test_one_input_function(lambda x: x.sqrt())
+  def test_log2(self): self._test_one_input_function(lambda x: x.log2())
+  def test_exp2(self): self._test_one_input_function(lambda x: x.exp2())

   # binary ops unit
   def test_add(self): self._test_two_input_function(lambda x,y: x+y)
   def test_mul(self): self._test_two_input_function(lambda x,y: x*y)

   # chain rule
-  def test_chain(self): self._test_one_input_function(lambda x: x.sin().sqrt(), lambda x: jnp.sqrt(jnp.sin(x)))
+  def test_chain(self): self._test_one_input_function(lambda x: x.sin().sqrt())
   def test_chain_binop(self): self._test_two_input_function(lambda x,y: (x*y)+x*y)
-  def test_big_add_sin(self): self._test_two_input_function(lambda x,y: x.sin()+3.0/y, lambda x,y: jnp.sin(x)+3.0/y)
+  def test_big_add_sin(self): self._test_two_input_function(lambda x,y: x.sin()+3.0/y)
   def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*y)
-  def test_where(self): self._test_two_input_function(lambda x,y: (x
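
Note (editor's aside, not part of the patch): on the "template like jax" question in the
commit message. jax.grad(f) returns a reusable gradient function, while
torch.autograd.grad evaluates gradients of an already-computed output with respect to
leaf tensors created with requires_grad=True, which is why the test builds fresh tensors
at every evaluation point. A jax.grad-style wrapper is still possible; the sketch below
is a hypothetical illustration for scalar functions, not code from the patch:

    import torch

    def torch_grad(f):
      # Hypothetical jax.grad-style factory: wraps f so each call builds a
      # fresh one-element leaf tensor and runs autograd once at that point.
      def gf(val: float) -> float:
        x = torch.tensor([val], dtype=torch.float, requires_grad=True)
        return torch.autograd.grad(f(x), x)[0].item()
      return gf

    gf = torch_grad(torch.sin)
    print(gf(0.0))  # 1.0, since d/dx sin(x) = cos(x) and cos(0) = 1

The per-call tensor construction is cheap at this scale, so the tests lose nothing by
not having a prebuilt gradient "template".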