diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index a43f08d9b7..d0ed90a648 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -85,14 +85,6 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a-b helper_test_generic_square('sub', 4096, f, f) - def test_constant_sub(self): - def f(a, b): return 1.0-a - helper_test_generic_square('sub', 4096, f, f) - - def test_constant_zero_sub(self): - def f(a, b): return 0.0-a - helper_test_generic_square('sub', 4096, f, f) - def test_pow(self): def f(a, b): return a.pow(b) helper_test_generic_square('pow', 2048, f, f) @@ -151,6 +143,14 @@ class TestSpeed(unittest.TestCase): def f(a, b): return a + b helper_test_generic_square('add', N, f, f) + def test_add_constant(self): + def f(a, b): return a+2.0 + helper_test_generic_square('add_constant', 4096, f, f) + + def test_add_constant_zero(self): + def f(a, b): return a+0.0 + helper_test_generic_square('add_constant_zero', 4096, f, f) + def test_add_sq(self): def f(a, b): return a*a + b*b helper_test_generic_square('add_sq', 4096, f, f)