hotfix: add test_simple_conv2d_bias

This commit is contained in:
George Hotz
2024-11-10 18:36:42 +08:00
parent 44c1fd5661
commit 745316493c

View File

@@ -1483,6 +1483,11 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
def test_simple_conv2d_bias(self):
helper_test_op([(1,4,9,9), (4,4,3,3), (4,)],
lambda x,w,b: torch.nn.functional.conv2d(x,w,b).relu(),
lambda x,w,b: Tensor.conv2d(x,w,b).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],