support pickling tensors and others (#3787)

* test pickle tensors

* pickle unrealized tensor

* pickle jit, don't save Device in every CompiledASTRunner

* real test of pickle, move delete
This commit is contained in:
George Hotz
2024-03-17 18:29:14 -07:00
committed by GitHub
parent 5ac1fa933f
commit bf3e1c4df2
9 changed files with 70 additions and 24 deletions

View File

@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
print(hex(b1._buf.value))
print(hex(b2._buf.value))
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [49, 8, 2], [8, 4, 1], precompiled=lib)
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
print("compiled")
prg([b0, b1, b2], {})
print("ran")

View File

@@ -38,7 +38,7 @@ if __name__ == "__main__":
lin.linearize()
ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
try:
ptx_prg = CompiledASTRunner(lin.name, ptx_src, dev, lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
except RuntimeError:
print("PTX FAIL")
continue

View File

@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}"""
CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
CompiledASTRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)

34
test/test_pickle.py Normal file
View File

@@ -0,0 +1,34 @@
import unittest, pickle
import numpy as np
from tinygrad import Tensor, TinyJit
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_unrealized_tensor(self):
t = Tensor.ones(10, 10)
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_jit(self):
@TinyJit
def add(a, b): return a+b+1
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
#import dill
#with dill.detect.trace(): dill.dumps(add)
st = pickle.dumps(add)
add_fxn = pickle.loads(st)
x = Tensor.ones(10, 10).contiguous().realize()
y = Tensor.ones(10, 10).contiguous().realize()
print("post jit")
out = add_fxn(x, y)
np.testing.assert_equal(out.numpy(), 3)
if __name__ == '__main__':
unittest.main()

View File

@@ -13,7 +13,7 @@ from test.helpers import is_dtype_supported
def _uops_to_prg(uops):
src = Device[Device.DEFAULT].compiler.render("test", uops)
has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
return CompiledASTRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))