mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
support pickling tensors and others (#3787)
* test pickle tensors
* pickle unrealized tensor
* pickle jit, don't save Device in every CompiledASTRunner
* real test of pickle, move delete
This commit is contained in:
4
test/external/external_hip_compiler_bug.py
vendored
4
test/external/external_hip_compiler_bug.py
vendored
@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
|
||||
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
|
||||
print(hex(b1._buf.value))
|
||||
print(hex(b2._buf.value))
|
||||
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [7, 1, 1], [8, 4, 1], precompiled=lib)
|
||||
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [49, 8, 2], [8, 4, 1], precompiled=lib)
|
||||
#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
|
||||
prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
|
||||
print("compiled")
|
||||
prg([b0, b1, b2], {})
|
||||
print("ran")
|
||||
|
||||
2
test/external/speed_compare_cuda_ptx.py
vendored
2
test/external/speed_compare_cuda_ptx.py
vendored
@@ -38,7 +38,7 @@ if __name__ == "__main__":
|
||||
lin.linearize()
|
||||
ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
|
||||
try:
|
||||
ptx_prg = CompiledASTRunner(lin.name, ptx_src, dev, lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
|
||||
ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
|
||||
except RuntimeError:
|
||||
print("PTX FAIL")
|
||||
continue
|
||||
|
||||
@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
|
||||
int idx = get_global_id(0);
|
||||
c[idx] = atan2(a[idx], b[idx]);
|
||||
}"""
|
||||
CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
|
||||
CompiledASTRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])
|
||||
|
||||
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
|
||||
|
||||
|
||||
34
test/test_pickle.py
Normal file
34
test/test_pickle.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import unittest, pickle
import numpy as np
from tinygrad import Tensor, TinyJit


class TestPickle(unittest.TestCase):
  """Tests that tinygrad objects survive a pickle round-trip (dumps -> loads)."""

  def test_pickle_realized_tensor(self):
    # A realized tensor has device-side buffer data; pickling must capture it.
    t = Tensor.rand(10, 10).realize()
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_unrealized_tensor(self):
    # An unrealized tensor is still a lazy graph; the graph itself must pickle.
    t = Tensor.ones(10, 10)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_jit(self):
    @TinyJit
    def add(a, b): return a+b+1
    # warm up the jit so captured kernels exist before pickling
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
    st = pickle.dumps(add)
    add_fxn = pickle.loads(st)

    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    print("post jit")
    out = add_fxn(x, y)
    # ones + ones + 1 == 3 everywhere
    np.testing.assert_equal(out.numpy(), 3)

if __name__ == '__main__':
  unittest.main()
|
||||
@@ -13,7 +13,7 @@ from test.helpers import is_dtype_supported
|
||||
def _uops_to_prg(uops):
  """Render a uop list with the default device's compiler and wrap it in a runnable CompiledASTRunner.

  Per this commit, the runner takes the device name string (Device.DEFAULT) rather
  than the Device object, so the Device is not stored in every CompiledASTRunner.
  Global/local sizes of [1] are used only when the backend has local workgroups.
  """
  src = Device[Device.DEFAULT].compiler.render("test", uops)
  has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
  return CompiledASTRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
||||
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
|
||||
uops.append(UOp(uop, dtype, tuple(vin), arg))
|
||||
|
||||
Reference in new issue · Block a user