mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
use __getnewargs__ to fix unpickling Variable (#3441)
it's recommended to use __getnewargs__ to update the args of classes that use __new__ when unpickling. It's preferred because it does not change the __new__ behavior.
This commit is contained in:
@@ -3,14 +3,9 @@ import unittest, pickle
|
||||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
|
||||
|
||||
class TestSymbolicPickle(unittest.TestCase):
|
||||
def test_pickle_variable(self):
|
||||
dat = Variable("a", 3, 8)
|
||||
datp = pickle.loads(pickle.dumps(dat))
|
||||
self.assertEqual(str(datp), "<a[3-8]>")
|
||||
def test_pickle_variable_times_2(self):
|
||||
dat = Variable("a", 3, 8)*2
|
||||
datp = pickle.loads(pickle.dumps(dat))
|
||||
self.assertEqual(str(datp), "<(a[3-8]*2)>")
|
||||
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
|
||||
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
|
||||
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
|
||||
Reference in New Issue
Block a user