diff --git a/test/test_jit.py b/test/test_jit.py index 4a146f5cf9..30842ca3b2 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -585,7 +585,19 @@ class TestJitFree(unittest.TestCase): savings_after_free = pre_free - GlobalCounters.mem_used # Different allocator implementations have different savings. - self.assertEqual(savings_after_free, 8196 if hasattr(Device[Device.DEFAULT].allocator, '_offset') else 2024) + expected_savings = 8196 if hasattr(Device[Device.DEFAULT].allocator, '_offset') else 2024 + + self.assertEqual(savings_after_free, expected_savings) + out = fxn(Tensor([11,1,2,3,4])) + self.assertEqual(out.item(), 13600) + + # Try one more time... + pre_free = GlobalCounters.mem_used + fxn.captured.free_intermediates() + fxn.captured.free_intermediates() # 2nd time to validate + savings_after_free = pre_free - GlobalCounters.mem_used + + self.assertEqual(savings_after_free, expected_savings) out = fxn(Tensor([11,1,2,3,4])) self.assertEqual(out.item(), 13600)