widen test_ops [low, high] and stricter atol (#4906)

default [low, high] changed from [-1.5, 1.5] to [-2, 2] (except tan).
dropped several explicit atol values where they were unnecessarily larger than the default 1e-6.
tested on mac, tinybox red / green
This commit is contained in:
chenyu
2024-06-10 20:47:09 -04:00
committed by GitHub
parent 97b05f567e
commit 798ea61377

View File

@@ -12,7 +12,7 @@ FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
PRINT_TENSORS = getenv("PRINT_TENSORS", 0)
def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3,
forward_only=False, vals=None, low=-1.5, high=1.5):
forward_only=False, vals=None, low=-2, high=2):
if tinygrad_fxn is None: tinygrad_fxn = torch_fxn
ts, tst = prepare_test_op(low, high, shps, vals, forward_only)
@@ -475,7 +475,9 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
def test_tan(self):
helper_test_op([(45,65)], lambda x: x.tan())
# NOTE: backward has much higher diff with input close to pi/2 and -pi/2
helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)
helper_test_op([(45,65)], lambda x: x.tan(), low=-5, high=5, forward_only=True)
helper_test_op([()], lambda x: x.tan())
def test_relu(self):
@@ -527,8 +529,8 @@ class TestOps(unittest.TestCase):
def test_sigmoid(self):
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid)
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=303)
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-300, high=-297)
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=400)
helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-400, high=-300)
helper_test_op([()], torch.sigmoid, Tensor.sigmoid)
def test_softplus(self):
helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6)
@@ -536,12 +538,12 @@ class TestOps(unittest.TestCase):
def test_gelu(self):
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=303)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=400)
helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=-400, high=-300)
def test_quick_gelu(self):
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=303)
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-300, high=-297)
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=400)
helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-400, high=-300)
helper_test_op([()], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu)
def test_elu(self):
@@ -672,17 +674,17 @@ class TestOps(unittest.TestCase):
@unittest.skipIf(IMAGE>0, "no 1d dot for images")
def test_dot_1d(self):
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(65), (65,45)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot)
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
with self.assertRaises(AssertionError):
@@ -710,28 +712,28 @@ class TestOps(unittest.TestCase):
)
def test_matmul_simple(self):
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot)
def test_matmul(self):
helper_test_op([(64), (64,99)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(64), (64,99)], lambda x,y: x.matmul(y), Tensor.dot)
@unittest.skipIf(IMAGE>0, "no batched matmul on images")
def test_matmul_batched(self):
helper_test_op([(3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)
@unittest.skipIf(IMAGE>0, "no batched matmul on images")
def test_matmul_batched_vector(self):
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot)
def test_small_gemm(self):
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3)
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y)
def test_small_gemm_range(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
np.arange(64,128,dtype=np.float32).reshape(8,8)])
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8),
np.arange(64,128,dtype=np.float32).reshape(8,8)])
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
def test_gemm(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot)
def test_big_gemm(self):
helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
@unittest.skipIf(IMAGE>0, "no 0 in shape matmul on images")
def test_gemm_with_zeros_shape(self):
helper_test_op([(8,8), (8,0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
@@ -742,14 +744,14 @@ class TestOps(unittest.TestCase):
helper_test_op([(0), (0,8)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
def test_broadcastdot(self):
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot)
with self.assertRaises(AssertionError):
a = Tensor(3.14)
b = Tensor.ones(3,3)
a @ b
def test_multidot(self):
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot)
helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot)
def test_sum_simple(self):
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
@@ -1190,7 +1192,7 @@ class TestOps(unittest.TestCase):
# internal shape (1, 1, 512, 62, 62, 512, 3, 3) overflows a int
helper_test_op([(1,256,64,64), (512,256,3,3)],
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-2)
lambda x,w: x.conv2d(w), atol=1e-3)
@unittest.skip("slow")
def test_large_bs_conv(self):
@@ -1198,116 +1200,116 @@ class TestOps(unittest.TestCase):
# (or cause the conv kernel to overflow short sampling coords)
helper_test_op([(4096,3,3,3), (1,3,3,3)],
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4, rtol=1e-2)
lambda x,w: x.conv2d(w), atol=1e-3)
@unittest.skip("slow")
def test_large_ic_conv(self):
# large input channel count can cause OpenCL image to exceed max image width on macOS
helper_test_op([(1,2048,3,3), (1,2048,3,3)],
lambda x,w: torch.nn.functional.conv2d(x, w),
lambda x,w: x.conv2d(w), atol=1e-4)
lambda x,w: x.conv2d(w))
def test_biased_conv2d(self):
C = 8
helper_test_op([(1,C,5,5), (C,C,1,1), (C,)],
lambda x,w,b: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w,b).relu(),w,b),
lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b), atol=1e-4)
lambda x,w,b: Tensor.conv2d(x,w,b).relu().conv2d(w,b))
def test_simple_conv2d(self):
helper_test_op([(1,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
lambda x,w: torch.nn.functional.conv3d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_padded_conv3d(self):
helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), grad_rtol=1e-5)
def test_simple_conv2d_m4(self):
helper_test_op([(1,16,18,18), (16,16,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
def test_simple_conv2d_1x1(self):
helper_test_op([(1,4,9,9), (4,4,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
def test_simple_conv2d_1x1_m4(self):
helper_test_op([(1,16,32,32), (16,16,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
def test_nested_conv2d(self):
helper_test_op([(1,32,9,9), (32,32,3,3), (32,32,3,3)],
lambda x,w1,w2: torch.nn.functional.conv2d(torch.nn.functional.conv2d(x,w1).relu(), w2).relu(),
lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w1,w2: x.conv2d(w1).relu().conv2d(w2).relu())
# expect reduce nodes == 3
def test_simple_conv2d_nhwc(self):
# weights (from tf): filter_height x filter_width x in_channels x out_channels
helper_test_op([(2,9,9,10), (3,3,10,20)],
lambda x,w: torch.nn.functional.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)).relu(),
lambda x,w: Tensor.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x.permute(0,3,1,2),w.permute(3,2,0,1)).relu(), atol=1e-5, grad_rtol=1e-5)
def test_simple_conv2d_batched(self):
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
# conv transpose
def test_simple_conv_transpose2d(self):
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), grad_rtol=1e-5)
def test_bias_conv_transpose2d(self):
helper_test_op([(2,4,9,9), (4,4,3,3), (4,)],
lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b).relu(),
lambda x,w,b: Tensor.conv_transpose2d(x,w,b).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w,b: Tensor.conv_transpose2d(x,w,b).relu(), grad_rtol=1e-5)
def test_grouped_conv_transpose2d(self):
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,groups=2).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w,groups=2).relu(), grad_rtol=1e-5)
def test_padded_conv_transpose2d(self):
for padding in [(1,2), (2,1), 2, 1, 0]:
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w,padding=padding).relu(), grad_rtol=1e-5)
def test_dilated_conv_transpose2d(self):
for dilation in [(1,2), (2,1), 2, 1]:
helper_test_op([(2,4,9,9), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w,dilation=dilation).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,dilation=dilation).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w,dilation=dilation).relu(), grad_rtol=1e-5)
def test_strided_conv_transpose2d(self):
for stride in [(2,1), (1,2), 1]:
helper_test_op([(2,4,4,5), (4,4,3,3)],
lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), grad_rtol=1e-5)
def test_output_padded_conv_transpose2d(self):
for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
lambda x,w,b: torch.nn.functional.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(),
lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w,b: Tensor.conv_transpose2d(x,w,b,output_padding=output_padding,stride=stride).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv_transpose3d(self):
helper_test_op([(2,4,9,9,9), (4,4,3,3,3)],
lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), grad_rtol=1e-5)
@unittest.skipIf((IMAGE>0), "no conv1d on images")
def test_conv1d(self):
@@ -1318,7 +1320,7 @@ class TestOps(unittest.TestCase):
with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H):
helper_test_op([(bs,cin,11), (6,cin//groups,H)],
lambda x,w: torch.nn.functional.conv1d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv1d on images")
def test_simple_padding_conv1d(self):
@@ -1329,14 +1331,14 @@ class TestOps(unittest.TestCase):
p = (1,1)
helper_test_op([(bs,cin,11), (6,cin//groups,H)],
lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=p).relu())
@unittest.skipIf(IMAGE>0, "no conv1d on images")
def test_strided_conv1d_simple(self):
bs, H = 2, 3
helper_test_op([(bs,1,5), (1,1,H)],
lambda x,w: torch.nn.functional.conv1d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,stride=2).relu())
@unittest.skipIf(IMAGE>0, "no conv1d on images")
def test_asymmetric_padding_conv1d(self):
@@ -1346,10 +1348,10 @@ class TestOps(unittest.TestCase):
for k in [2]:
helper_test_op([(1,1,n), (1,1,k)],
lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=p).relu())
helper_test_op([(1,1,n), (1,1,k)],
lambda x,w: torch.nn.functional.conv1d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=p).relu())
def _test_conv2d(self, bs=1, cin=1):
for H in [1,2,3]:
@@ -1358,7 +1360,7 @@ class TestOps(unittest.TestCase):
with self.subTest(batch_size=bs, channels=cin, groups=groups, height=H, width=W):
helper_test_op([(bs,cin,11,7), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_conv2d(self): self._test_conv2d(bs=1, cin=3)
def test_conv2d_bs_4_cin_3(self): self._test_conv2d(bs=4, cin=3)
def test_conv2d_bs_1_cin_1(self): self._test_conv2d(bs=1, cin=1)
@@ -1373,7 +1375,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(bs,cin,64,64), (6,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
# needed to relax tolerance on NVIDIA
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-3, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
def test_simple_grouped_conv2d(self):
bs = 1
@@ -1382,7 +1384,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_medium_grouped_conv2d(self):
bs = 1
@@ -1391,7 +1393,7 @@ class TestOps(unittest.TestCase):
cin = 2
helper_test_op([(bs,groups*cin,1,1), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_depthwise_conv2d(self):
bs = 1
@@ -1400,7 +1402,7 @@ class TestOps(unittest.TestCase):
cin = 1
helper_test_op([(bs,groups*cin,32,32), (groups*rcout,cin,1,1)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_grouped_conv2d(self):
bs = 4
@@ -1409,7 +1411,7 @@ class TestOps(unittest.TestCase):
cin = 3
helper_test_op([(bs,groups*cin,5,5), (groups*rcout,cin,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_fancy_conv2d(self):
bs = 2
@@ -1419,13 +1421,13 @@ class TestOps(unittest.TestCase):
H,W = 3,3
helper_test_op([(bs,cin,11,28), (groups*cout,cin//groups,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,groups=groups).relu(),
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), atol=1e-4, grad_rtol=1e-5)
lambda x,w: Tensor.conv2d(x,w,groups=groups).relu(), grad_rtol=1e-5)
def test_strided_conv2d_simple(self):
bs,H,W = 2,3,1
helper_test_op([(bs,1,5,1), (1,1,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=2).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,stride=2).relu())
def test_strided_conv2d(self):
bs = 4
@@ -1434,26 +1436,26 @@ class TestOps(unittest.TestCase):
with self.subTest(stride := 2):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=2).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,stride=stride).relu())
with self.subTest(stride := (2,1)):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,stride=stride).relu(),
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,stride=(2,1)).relu())
def test_negative_padding_conv2d(self):
n,k = 10, 3
helper_test_op([(1,1,n,n), (1,1,k,k)],
lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:-1, 1:-1],w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=-1).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=-1).relu())
helper_test_op([(1,1,n,n), (1,1,k,k)],
lambda x,w: torch.nn.functional.conv2d(x[:, :, 1:, 1:],w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=(-1,0,-1,0)).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=(-1,0,-1,0)).relu())
def test_simple_padding_conv2d(self):
p = (1,1,1,1)
helper_test_op(None,
lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4, vals=[[[[[2.,3.]]]], [[[[1.]]]]])
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), vals=[[[[[2.,3.]]]], [[[[1.]]]]])
def test_asymmetric_padding_conv2d(self):
for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]:
@@ -1462,34 +1464,34 @@ class TestOps(unittest.TestCase):
for k in [2]:
helper_test_op([(1,1,n,n), (1,1,k,k)],
lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=p).relu())
helper_test_op([(1,1,n,n), (1,1,k,k)],
lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=p).relu())
def test_padded_conv2d_p21(self):
bs,cin,H,W,padding = 4, 3, 3, 3, (2,1)
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu())
def test_padded_conv2d_p22(self):
bs,cin,H,W,padding = 4, 3, 3, 3, (2,2)
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu())
def test_padded_conv2d_1x1(self):
bs,cin,H,W,padding = 4, 3, 1, 1, 2
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu())
def test_padded_conv2d_bs1(self):
bs,cin,H,W,padding = 1, 3, 3, 3, 1
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,padding=padding).relu())
def test_padding_add(self):
helper_test_op([(64,64), (60,60)],
@@ -1504,7 +1506,7 @@ class TestOps(unittest.TestCase):
with self.subTest(dilation := d):
helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(),
lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu(), atol=1e-4)
lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu())
def test_maxpool2d_simple(self):
ksz = (2,2)
@@ -1621,10 +1623,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: x.clip(3, 0))
def test_matvecmat(self):
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4)
helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z)
def test_matvec(self):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu(), atol=1e-4)
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())
# this was the failure in llama early realizing freqs_cis
def test_double_slice(self):