From 2da734920e45f437f80e7d74201450d5b8a7ebef Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 18 Feb 2024 10:28:37 -0500 Subject: [PATCH] use __getnewargs__ to fix unpickling Variable (#3441) it's recommended to use __getnewargs__ to update the args of classes that use __new__ when unpickling. It's preferred because it does not change the __new__ behavior. --- test/unit/test_symbolic.py | 11 +++-------- tinygrad/shape/symbolic.py | 3 ++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py index f13dff4fbb..ce5d26619a 100644 --- a/test/unit/test_symbolic.py +++ b/test/unit/test_symbolic.py @@ -3,14 +3,9 @@ import unittest, pickle from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer class TestSymbolicPickle(unittest.TestCase): - def test_pickle_variable(self): - dat = Variable("a", 3, 8) - datp = pickle.loads(pickle.dumps(dat)) - self.assertEqual(str(datp), "") - def test_pickle_variable_times_2(self): - dat = Variable("a", 3, 8)*2 - datp = pickle.loads(pickle.dumps(dat)) - self.assertEqual(str(datp), "<(a[3-8]*2)>") + def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x))) + def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8)) + def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2) class TestSymbolic(unittest.TestCase): def helper_test_variable(self, v, n, m, s): diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index adeb965a3a..b6cbf02f48 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -118,12 +118,13 @@ class Node: class Variable(Node): def __new__(cls, *args): - if len(args) == 0: return super().__new__(cls) # fix pickle expr, nmin, nmax = args assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}" if nmin == nmax: return NumNode(nmin) return super().__new__(cls) + def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling + def __init__(self, expr:str, nmin:int, nmax:int): self.expr, self.min, self.max = expr, nmin, nmax self._val: Optional[int] = None