nah, no sign, it's not what you want. use relu
@@ -6,12 +6,12 @@ import timeit
 import functools
 from tinygrad.tensor import Tensor, DEFAULT_DEVICE, Device
 
-def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=0, rtol=1e-6, grad_atol=0, grad_rtol=1e-6, forward_only=False, vals=None):
+def helper_test_op(shps, torch_fxn, tinygrad_fxn, atol=1e-6, rtol=1e-3, grad_atol=1e-6, grad_rtol=1e-3, forward_only=False, vals=None):
   torch.manual_seed(0)
   if shps is None:
     ts = [torch.tensor(x, requires_grad=True) for x in vals]
   else:
-    ts = [torch.rand(x, requires_grad=True) for x in shps]
+    ts = [torch.tensor((np.random.random(size=x).astype(np.float32)-0.5)*20, requires_grad=True) for x in shps]
 
   tst = [Tensor(x.detach().numpy()) for x in ts]
   out = torch_fxn(*ts)
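
For context on the helper change: inputs now come from a uniform [-10, 10) float32 distribution instead of torch.rand's [0, 1), so every op is exercised on negative values and larger magnitudes (which matters for a commit that picks relu over sign), and the default tolerances are loosened from atol=0, rtol=1e-6 to atol=1e-6, rtol=1e-3 to absorb the extra float32 rounding error. A minimal numpy-only sketch of why the wider range forces looser absolute tolerances; the shape and the squaring op here are illustrative, not taken from the file:

import numpy as np

# Sketch: the same elementwise op in float32, checked against a float64
# reference, on the old [0, 1) inputs versus the new [-10, 10) inputs.
rng = np.random.default_rng(0)
small = rng.random((45, 65), dtype=np.float32)                # old torch.rand-style range
big = ((rng.random((45, 65)) - 0.5) * 20).astype(np.float32)  # new signed, larger range

for name, x in [("[0,1)", small), ("[-10,10)", big)]:
  err = np.abs((x * x).astype(np.float64) - x.astype(np.float64) ** 2).max()
  print(name, err)  # absolute error is roughly 100x larger on the wide range
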
@@ -79,8 +79,8 @@ class TestOps(unittest.TestCase):
   def test_dot(self):
     helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot)
   def test_multidot(self):
-    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
-    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
+    helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
+    helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
   def test_sum(self):
     helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum)
     helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2)))
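
The explicit atol=1e-4 on the batched matmuls follows from the new input range: each output element of a (10,45,65) @ (10,65,45) product accumulates 65 float32 products of values as large as ~100. A rough check of that error scale against a float64 reference, with shapes mirroring the test above:

import numpy as np

rng = np.random.default_rng(0)
a = ((rng.random((10, 45, 65)) - 0.5) * 20).astype(np.float32)
b = ((rng.random((10, 65, 45)) - 0.5) * 20).astype(np.float32)

# largest absolute gap between the float32 batched matmul and a float64 reference
err = np.abs((a @ b).astype(np.float64) - a.astype(np.float64) @ b.astype(np.float64)).max()
print(err)  # typically in the 1e-4 to 1e-3 range, far above the old 1e-6 default
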
@@ -159,7 +159,7 @@ class TestOps(unittest.TestCase):
       with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
         helper_test_op([(bs,cin,11,28), (6,cin//groups,H,W)],
           lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
-          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
+          lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
 
   def test_strided_conv2d(self):
     bs = 4
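
Same story for the grouped convolution: each conv2d output sums cin//groups * H * W float32 products of signed values up to ~10, so atol=1e-4 covers the accumulation error. A quick torch-only check; the concrete sizes are assumptions picked to resemble the test, not copied from it:

import numpy as np
import torch

x = torch.tensor(((np.random.random((4, 6, 11, 28)) - 0.5) * 20).astype(np.float32))
w = torch.tensor(((np.random.random((6, 2, 3, 3)) - 0.5) * 20).astype(np.float32))

# float32 conv2d against a float64 reference; groups=3, so each output
# element accumulates 2*3*3 = 18 products of values as large as ~100
out32 = torch.nn.functional.conv2d(x, w, groups=3)
out64 = torch.nn.functional.conv2d(x.double(), w.double(), groups=3)
print((out32.double() - out64).abs().max().item())  # well above 1e-6, around 1e-4 or below
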
@@ -168,18 +168,19 @@ class TestOps(unittest.TestCase):
     with self.subTest(stride := 2):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
         lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu())
+        lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), atol=1e-4)
     with self.subTest(stride := (2,1)):
       helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
         lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
-        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu())
+        lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=1e-4)
 
   def test_maxpool2d(self):
     for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
       with self.subTest(kernel_size=ksz):
         helper_test_op([(32,2,110,28)],
           lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz),
-          lambda x: Tensor.max_pool2d(x, kernel_size=ksz))
+          # TODO: why is this tolerance so high?
+          lambda x: Tensor.max_pool2d(x, kernel_size=ksz), grad_atol=1e-4)
 
   def test_avgpool2d(self):
     shape = (32,2,111,28)
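
One plausible, unconfirmed answer to the TODO on the maxpool gradient tolerance: max_pool2d gradients are routed by argmax, so whenever a pooling window contains tied maxima, two implementations can legitimately send the gradient to different positions, producing gradient gaps far larger than rounding error (though with continuous random inputs exact ties are rare). A small torch-only sketch of the tie case:

import torch

x = torch.ones(1, 1, 2, 2, requires_grad=True)  # all four window values tie
y = torch.nn.functional.max_pool2d(x, kernel_size=2)
y.sum().backward()
print(x.grad)  # torch routes the whole gradient to a single tied element;
               # an implementation that split it across ties would disagree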