diff --git a/test/backend/test_pickle.py b/test/backend/test_pickle.py index dbf9dae402..3cc617d272 100644 --- a/test/backend/test_pickle.py +++ b/test/backend/test_pickle.py @@ -125,6 +125,13 @@ class TestPickle(unittest.TestCase): out = add_fxn(x, y) np.testing.assert_equal(out.numpy(), 102) + def test_pickle_jit_no_del(self): + @TinyJit + def fn(x): return x + 1.0 + for _ in range(3): fn(Tensor.randn(4)) + loaded = pickle.loads(pickle.dumps(fn)) + self.assertEqual(loaded(Tensor([1.0,2.0,3.0,4.0])).tolist(), [2.0,3.0,4.0,5.0]) + def test_pickle_context_var(self): v = ContextVar("test_var", 0) with Context(test_var=1):