hotfix: pickle jit works if you delete the function

This commit is contained in:
George Hotz
2024-05-05 10:14:03 -07:00
parent 12be536c06
commit f95658bc3e

View File

@@ -33,16 +33,15 @@ class TestPickle(unittest.TestCase):
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
@unittest.expectedFailure
def test_pickle_jit(self):
@TinyJit
def add(a, b): return a+b+1
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
#import dill
#with dill.detect.trace(): dill.dumps(add)
del add.fxn # pickling the JIT requires the function to be deleted
st = pickle.dumps(add)
add_fxn = pickle.loads(st)
del add
add_fxn = pickle.loads(st)
x = Tensor.ones(10, 10).contiguous().realize()
y = Tensor.ones(10, 10).contiguous().realize()
print("post jit")