Files
tinygrad/test/test_pickle.py
qazal 982086f54c UOps.VALID try 2 (#6623)
* make UOps.VALID compile

* fixable tests

* bufs dedup

* cleanup the CONST spec

* regenerate dataset with graph_rewrite

```py
def rewrite_const(const:UOp, st_src:UOp) -> UOp:
  st: ShapeTracker = st_src.arg
  return UOp(UOps.VALID, dtypes.bool, (st.to_uop(),)).where(UOp.const(const.dtype, const.arg), UOp.const(const.dtype, 0))
pm = PatternMatcher([(UPat(UOps.CONST, name="const", src=(UPat(UOps.SHAPETRACKER, name="st_src"),)), rewrite_const)])
```

* rm arg

* remove arg

* revert arg removal

This reverts commit 2c35c75c95.

* red test_pickle_define_var
2024-09-21 14:19:25 +08:00

125 lines
4.8 KiB
Python

import unittest, pickle
import numpy as np
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import PtrDType, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_unrealized_tensor(self):
t = Tensor.ones(10, 10)
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_variable(self):
v = Variable("i", 1, 20).bind(10)
t1 = Tensor.ones(10, v).contiguous()
t2 = Tensor.ones(10, v).contiguous()
ret = (t1+t2).sum(1)
st = pickle.dumps(ret)
del ret
vt2 = pickle.loads(st)
np.testing.assert_equal(vt2.numpy(), 20)
def test_pickle_buffer_view(self):
t = Tensor.arange(10, device="CLANG").contiguous().realize()
vt = t[3:5].contiguous().realize()
assert hasattr(vt.lazydata.buffer, 'base')
ref_value = vt.tolist()
st = pickle.dumps(vt)
del t, vt
vt2 = pickle.loads(st)
assert hasattr(vt2.lazydata.buffer, 'base')
assert ref_value == vt2.tolist()
def test_pickle_numpy(self):
t = Tensor(np.array([1,2,3,4.]))
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_jit(self):
@TinyJit
def add(a, b): return a.sum()+b+1
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
st = pickle.dumps(add)
del add
add_fxn = pickle.loads(st)
x = Tensor.ones(10, 10).contiguous().realize()
y = Tensor.ones(10, 10).contiguous().realize()
print("post jit")
out = add_fxn(x, y)
np.testing.assert_equal(out.numpy(), 102)
def test_pickle_schedule(self):
a = Tensor([1,2])
out = a + 2
sched = create_schedule([out.lazydata])
pk = pickle.dumps(sched)
sched_pk = pickle.loads(pk)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
def test_pickle_define_var(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
x2,)),
UOp.define_var("i", dtypes.int, 1, 10),
x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12,
UOp(UOps.CONST, dtypes.int, arg=3, src=()),
x14,)),)),)),)),)),)),))
p = Kernel(ast).to_program(name_override="test")
ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test")
self.assertEqual(ps.src, p.src)
class TestPickleJIT(unittest.TestCase):
@classmethod
def setUpClass(cls):
@TinyJit
def add(a, b): return a.sum()+b+1
for _ in range(3): add(Tensor.rand(1000, 1000), Tensor.rand(1000, 1000))
cls.st = pickle.dumps(add)
del add
def test_inspect(self):
import io
class FakeClass:
def __init__(self, *args, **kwargs):
print(self.module, self.name)
class InspectUnpickler(pickle.Unpickler):
def find_class(self, module, name): return type("SpecializedFakeClass", (FakeClass,), {"name": name, "module": module})
InspectUnpickler(io.BytesIO(self.st)).load()
@unittest.skip("we are still saving intermediate buffers")
def test_size(self):
# confirm no intermediate buffers are saved
self.assertLess(len(self.st), 1_000_000)
if __name__ == '__main__':
unittest.main()