diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index 0fb110088a..13c7da92f0 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -174,5 +174,56 @@ class TestOpt(unittest.TestCase): np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" + @unittest.skip("this is broken") + def test_no_binop_rerun(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = a*b + d = (a*b).reshape(16, 16, 1) + c.realize() + d.realize() + assert len(GlobalCounters.cache) == 1, "binop was rerun!" + np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3) + + @unittest.skip("this is broken") + def test_no_binop_rerun_alt(self): + a = Tensor.randn(16, 16) + b = Tensor.randn(16, 16) + with CLCache(): + c = (a*b).reshape(16, 16, 1) + d = a*b + c.realize() + d.realize() + assert len(GlobalCounters.cache) == 1, "binop was rerun!" + np.testing.assert_allclose(c.numpy(), d.numpy(), rtol=1e-3) + + # TODO: should be okay with PUSH_PERMUTES + def test_no_reduceop_rerun(self): + if PUSH_PERMUTES: return + a = Tensor.randn(16, 16, 16) + with CLCache(): + c = a.sum(2) + d = a.sum(2).permute(1,0) + c.realize() + d.realize() + cache_len = len(GlobalCounters.cache) + np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy()) + assert cache_len == 1, "reduceop was rerun!" + + # TODO: should be okay with PUSH_PERMUTES + def test_no_reduceop_rerun_alt(self): + if PUSH_PERMUTES: return + a = Tensor.randn(16, 16, 16) + with CLCache(): + c = a.sum(2).permute(1,0) + d = a.sum(2) + c.realize() + d.realize() + cache_len = len(GlobalCounters.cache) + np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0)) + assert cache_len == 1, "reduceop was rerun!" + + if __name__ == '__main__': unittest.main()