use __getnewargs__ to fix unpickling Variable (#3441)

it's recommended to use __getnewargs__ to update the args of classes that use __new__ when unpickling.
It's preferred because it does not change the __new__ behavior.
This commit is contained in:
chenyu
2024-02-18 10:28:37 -05:00
committed by GitHub
parent 5647148937
commit 2da734920e
2 changed files with 5 additions and 9 deletions

View File

@@ -3,14 +3,9 @@ import unittest, pickle
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
class TestSymbolicPickle(unittest.TestCase):
def test_pickle_variable(self):
dat = Variable("a", 3, 8)
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<a[3-8]>")
def test_pickle_variable_times_2(self):
dat = Variable("a", 3, 8)*2
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<(a[3-8]*2)>")
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):