tests should use rtol unless special case (#82)
@@ -5,7 +5,7 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, GPU
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7, gpu=False, forward_only=False):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, gpu=False, forward_only=False):
   ts = [torch.rand(x, requires_grad=True) for x in shps]
   tst = [Tensor(x.detach().numpy()) for x in ts]
   if gpu:
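With the new defaults the checks become purely relative: atol drops to 0 and rtol/grad_rtol take over at 1e-6, i.e. roughly six significant digits of agreement regardless of magnitude. A minimal sketch of the criterion assert_allclose applies (the formula is NumPy's documented one; the sample values are invented for illustration):

import numpy as np

desired = np.array([1e-3, 1.0, 1e3])   # reference values, e.g. the torch output
actual = desired * (1 + 5e-7)          # candidate values, off by ~0.5 parts per million

# assert_allclose passes when |actual - desired| <= atol + rtol * |desired|.
# With atol=0, rtol=1e-6 the allowed error scales with each element, so the 1e3 entry
# (absolute error ~5e-4) passes, which a small fixed atol alone could not allow.
np.testing.assert_allclose(actual, desired, atol=0, rtol=1e-6)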
@@ -14,14 +14,14 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-7, grad_atol=1e-7, gpu
   out = torch_fxn(*ts)
   ret = tinygrad_fxn(*tst)
 
-  np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol)
+  np.testing.assert_allclose(ret.cpu().data, out.detach().numpy(), atol=atol, rtol=rtol)
 
   if not forward_only:
     out.mean().backward()
     ret.mean().backward()
 
     for t, tt in zip(ts, tst):
-      np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol)
+      np.testing.assert_allclose(t.grad, tt.grad.cpu().data, atol=grad_atol, rtol=grad_rtol)
 
   # speed
   torch_fp = timeit.Timer(functools.partial(torch_fxn, *ts)).timeit(5) * 1000/5
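A side effect of checking gradients after out.mean().backward() is that gradient entries are scaled down by the element count, so the old grad_atol=1e-7 was fairly loose in relative terms. A back-of-the-envelope sketch using the (45,65) shapes from the tests below:

# mean() over a (45,65) tensor scales each upstream gradient by 1/(45*65) ≈ 3.4e-4,
# so for the elementwise tests the gradient entries land around 1e-4 or smaller.
# A fixed grad_atol=1e-7 therefore tolerated roughly 3e-4 of *relative* error at that
# scale, while grad_rtol=1e-6 asks for about six significant digits at any scale.
scale = 1.0 / (45 * 65)
relative_slack_of_old_atol = 1e-7 / scale
print(relative_slack_of_old_atol)   # ≈ 2.9e-4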
@@ -46,8 +46,7 @@ class TestOps(unittest.TestCase):
   def test_mul(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x*y, Tensor.mul, gpu=self.gpu)
   def test_div(self):
-    # TODO: why does this need more tolerance?
-    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, atol=1e-3, grad_atol=1e-3, gpu=self.gpu)
+    helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div, gpu=self.gpu)
   def test_pow(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, gpu=self.gpu)
   def test_sqrt(self):
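The removed TODO in test_div is consistent with the point of the change: with torch.rand inputs the divisor can sit near zero, so quotients and their gradients can be large, and a fixed absolute tolerance had to be inflated to 1e-3 to cover them, whereas a relative tolerance scales with the result automatically. A small sketch of that effect (the operands are invented; float32 matches the tensors used in the tests):

import numpy as np

x = np.float32(0.7)
y = np.float32(1e-3)                    # a divisor near zero, as torch.rand can produce
q64 = np.float64(x) / np.float64(y)     # higher-precision reference quotient (~700)
q32 = x / y                             # single-precision quotient

abs_err = abs(np.float64(q32) - q64)    # absolute error grows with the quotient's magnitude
rel_err = abs_err / abs(q64)            # relative error stays within ~float32 epsilon (1.2e-7)
print(abs_err, rel_err)                 # a fixed atol must cover abs_err; rtol=1e-6 covers rel_err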
@@ -57,11 +56,11 @@ class TestOps(unittest.TestCase):
   def test_sigmoid(self):
     helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, gpu=self.gpu)
   def test_dot(self):
-    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5, gpu=self.gpu)
+    helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, gpu=self.gpu)
   def test_sum(self):
-    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, atol=1e-4, gpu=self.gpu)
+    helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum, gpu=self.gpu)
   def test_logsoftmax(self):
-    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-5, gpu=self.gpu)
+    helper_test_op([(45,65)], lambda x: torch.nn.LogSoftmax(dim=1)(x), Tensor.logsoftmax, atol=1e-7, grad_atol=1e-7, gpu=self.gpu)
 
   def test_pad2d(self):
     helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,1,1,1)), lambda x: x.pad2d(padding=(1,1,1,1)), gpu=self.gpu)
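test_dot and test_sum likewise lose their hand-tuned atol: a reduction over 65 products of O(1) values gives results of roughly 16, so rounding error grows with the result and a relative bound fits without per-test tuning. test_logsoftmax keeps a small explicit atol/grad_atol, matching the "unless special case" in the commit title. A hedged way to see the matmul behaviour, comparing a float32 product against a float64 reference (shapes borrowed from test_dot; the exact error depends on the BLAS in use):

import numpy as np

a = np.random.rand(45, 65).astype(np.float32)
b = np.random.rand(65, 100).astype(np.float32)

ref = a.astype(np.float64) @ b.astype(np.float64)   # higher-precision reference
out = a @ b                                          # float32 accumulation

# Each output element sums 65 products of O(1) values, so absolute errors scale with
# the O(10) results while relative errors stay a small multiple of float32 epsilon.
rel_err = np.max(np.abs(out - ref) / np.abs(ref))
print(rel_err)   # typically on the order of 1e-7, comfortably inside rtol=1e-6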
@@ -74,7 +73,7 @@ class TestOps(unittest.TestCase):
           for W in [1,2,3,5]:
             helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
               lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-              lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
+              lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), gpu=self.gpu, grad_rtol=1e-5, forward_only=self.gpu)
 
   def test_strided_conv2d(self):
     bs = 4
@@ -82,10 +81,10 @@ class TestOps(unittest.TestCase):
     H,W = 3,3
     helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
-      lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
+      lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), gpu=self.gpu, forward_only=self.gpu)
     helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=(2,1)).relu(),
-      lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=2e-5, grad_atol=2e-6, gpu=self.gpu, forward_only=self.gpu)
+      lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), gpu=self.gpu, forward_only=self.gpu)
 
   def test_maxpool2d(self):
     # TODO merge into test_maxpool2d_strided when backward() is implemented
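Putting it together, a new test method inside TestOps would now lean on the defaults and only override a tolerance when an op genuinely needs it, the way the grouped-conv test above keeps grad_rtol=1e-5. An illustrative sketch (the method names are invented and not part of this commit):

# Hypothetical additions to TestOps, shown only to illustrate the calling convention.
def test_relu_example(self):
  # no tolerances passed: the relative defaults (atol=0, rtol=1e-6, grad_rtol=1e-6) apply
  helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, gpu=self.gpu)

def test_noisy_op_example(self):
  # a noisier op overrides only the tolerance it needs and keeps the rest relative
  helper_test_op([(45,65)], lambda x: x.relu(), Tensor.relu, grad_rtol=1e-5, gpu=self.gpu)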