From 2d3c7e4d4e7efa24fb18a3ec5825fab1fd48e06c Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 1 Aug 2024 12:39:59 -0700 Subject: [PATCH] some TestPickleJIT tests (#5860) * some TestPickleJIT tests * hotfix: print which opencl device we are using --- test/test_pickle.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/test_pickle.py b/test/test_pickle.py index 75ae39b2e1..7ab038b989 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -65,5 +65,28 @@ class TestPickle(unittest.TestCase): sched_pk = pickle.loads(pk) assert sched_pk[-1].ast == sched[-1].ast +class TestPickleJIT(unittest.TestCase): + @classmethod + def setUpClass(cls): + @TinyJit + def add(a, b): return a.sum()+b+1 + for _ in range(3): add(Tensor.rand(1000, 1000), Tensor.rand(1000, 1000)) + cls.st = pickle.dumps(add) + del add + + def test_inspect(self): + import io + class FakeClass: + def __init__(self, *args, **kwargs): + print(self.module, self.name) + class InspectUnpickler(pickle.Unpickler): + def find_class(self, module, name): return type("SpecializedFakeClass", (FakeClass,), {"name": name, "module": module}) + InspectUnpickler(io.BytesIO(self.st)).load() + + @unittest.skip("we are still saving intermediate buffers") + def test_size(self): + # confirm no intermediate buffers are saved + self.assertLess(len(self.st), 1_000_000) + if __name__ == '__main__': unittest.main()