import torch import time import numpy as np import unittest from tinygrad.tensor import Tensor from tinygrad.helpers import getenv FORWARD_ONLY = getenv("FORWARD_ONLY", 0) PRINT_TENSORS = getenv("PRINT_TENSORS", 0) def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3): if tinygrad_fxn is None: tinygrad_fxn = torch_fxn torch.manual_seed(0) np.random.seed(0) if shps is None: ts = [torch.tensor(x, requires_grad=True) for x in vals] else: ts = [torch.tensor((np.random.random(size=x).astype(np.float32)+a)*b, requires_grad=True) for x in shps] tst = [Tensor(x.detach().numpy(), requires_grad=not FORWARD_ONLY) for x in ts] st = time.monotonic() out = torch_fxn(*ts) torch_fp = time.monotonic() - st st = time.monotonic() ret = tinygrad_fxn(*tst).realize() tinygrad_fp = time.monotonic() - st def compare(s, x,y,atol,rtol): if PRINT_TENSORS: print(s, x, y) if y.shape != tuple(): assert x.shape == y.shape, f"shape mismatch (tinygrad){x.shape} != (torch){y.shape}" try: np.testing.assert_allclose(x,y, atol=atol, rtol=rtol) except Exception: raise Exception(f"{s} failed shape {x.shape}") compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol) torch_fbp, tinygrad_fbp = np.nan, np.nan if not forward_only and not FORWARD_ONLY: st = time.monotonic() (out+1).square().mean().backward() torch_fbp = time.monotonic() - st st = time.monotonic() (ret+1).square().mean().backward() for tt in tst: tt.grad.realize() tinygrad_fbp = time.monotonic() - st for i, (t, tt) in enumerate(zip(ts, tst)): compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol) print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") class TestOps(unittest.TestCase): def test_zeros(self): helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) def test_ones(self): helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True) def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) def test_arange(self): helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) def _test_cmp(self, fxn, reverse=True): for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]: helper_test_op(shps, fxn, fxn, forward_only=True) helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]]) helper_test_op(None, lambda x,y: fxn(x,2), lambda x,y: fxn(x,2), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) if reverse: helper_test_op(None, lambda x,y: fxn(2,y), lambda x,y: fxn(2,y), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) def test_cmp_eq(self): self._test_cmp(lambda x,y: x==y, reverse=False) def test_cmp_gt(self): self._test_cmp(lambda x,y: x>y) def test_cmp_ge(self): self._test_cmp(lambda x,y: x>=y) def test_cmp_lt(self): self._test_cmp(lambda x,y: x