mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
hotfix: pickle jit works if you delete the function
This commit is contained in:
@@ -33,16 +33,15 @@ class TestPickle(unittest.TestCase):
|
|||||||
t2:Tensor = pickle.loads(st)
|
t2:Tensor = pickle.loads(st)
|
||||||
np.testing.assert_equal(t.numpy(), t2.numpy())
|
np.testing.assert_equal(t.numpy(), t2.numpy())
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_pickle_jit(self):
|
def test_pickle_jit(self):
|
||||||
@TinyJit
|
@TinyJit
|
||||||
def add(a, b): return a+b+1
|
def add(a, b): return a+b+1
|
||||||
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
|
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
|
||||||
#import dill
|
del add.fxn # pickling the JIT requires the function to be deleted
|
||||||
#with dill.detect.trace(): dill.dumps(add)
|
|
||||||
st = pickle.dumps(add)
|
st = pickle.dumps(add)
|
||||||
add_fxn = pickle.loads(st)
|
del add
|
||||||
|
|
||||||
|
add_fxn = pickle.loads(st)
|
||||||
x = Tensor.ones(10, 10).contiguous().realize()
|
x = Tensor.ones(10, 10).contiguous().realize()
|
||||||
y = Tensor.ones(10, 10).contiguous().realize()
|
y = Tensor.ones(10, 10).contiguous().realize()
|
||||||
print("post jit")
|
print("post jit")
|
||||||
|
|||||||
Reference in New Issue
Block a user