support pickling tensors and others (#3787)

* test pickle tensors

* pickle unrealized tensor

* pickle jit, don't save Device in every CompiledASTRunner

* real test of pickle, move delete
George Hotz authored on 2024-03-17 18:29:14 -07:00, committed by GitHub
parent 5ac1fa933f
commit bf3e1c4df2
9 changed files with 70 additions and 24 deletions
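
A minimal usage sketch (not part of the diff) of what this change enables, mirroring the new test/test_pickle.py below; it assumes an installed tinygrad with a working default device:

import pickle
from tinygrad import Tensor

t = Tensor.rand(4, 4).realize()
t2 = pickle.loads(pickle.dumps(t))      # realized tensors now round-trip through pickle
assert (t.numpy() == t2.numpy()).all()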


@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
print(hex(b1._buf.value))
print(hex(b2._buf.value))
-#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [7, 1, 1], [8, 4, 1], precompiled=lib)
-prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [49, 8, 2], [8, 4, 1], precompiled=lib)
+#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
+prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
print("compiled")
prg([b0, b1, b2], {})
print("ran")


@@ -38,7 +38,7 @@ if __name__ == "__main__":
lin.linearize()
ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
try:
-ptx_prg = CompiledASTRunner(lin.name, ptx_src, dev, lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
+ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
except RuntimeError:
print("PTX FAIL")
continue


@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}"""
CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
CompiledASTRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)

test/test_pickle.py (new file, 34 lines added)

@@ -0,0 +1,34 @@
import unittest, pickle
import numpy as np
from tinygrad import Tensor, TinyJit
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_unrealized_tensor(self):
t = Tensor.ones(10, 10)
st = pickle.dumps(t)
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_jit(self):
@TinyJit
def add(a, b): return a+b+1
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
#import dill
#with dill.detect.trace(): dill.dumps(add)
st = pickle.dumps(add)
add_fxn = pickle.loads(st)
x = Tensor.ones(10, 10).contiguous().realize()
y = Tensor.ones(10, 10).contiguous().realize()
print("post jit")
out = add_fxn(x, y)
np.testing.assert_equal(out.numpy(), 3)
if __name__ == '__main__':
unittest.main()


@@ -13,7 +13,7 @@ from test.helpers import is_dtype_supported
def _uops_to_prg(uops):
src = Device[Device.DEFAULT].compiler.render("test", uops)
has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
-return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
+return CompiledASTRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))


@@ -75,7 +75,7 @@ class BufferOptions:
signal: bool = False
class Buffer:
-def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None):
+def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
assert isinstance(dtype, DType)
if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
self.device, self.size, self.dtype, self.d, self.options = device, size, dtype, Device[device], options
@@ -83,6 +83,11 @@ class Buffer:
self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, options)
# TODO: mem_used for all devices
if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
+if initial_value is not None: self.copyin(memoryview(initial_value))
+def __reduce__(self):
+buf = bytearray(self.nbytes)
+self.copyout(memoryview(buf))
+return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
@property
def nbytes(self): return self.size*self.dtype.itemsize
def __del__(self):
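
For context, Buffer's new __reduce__ follows the standard pickle protocol: return a callable plus the arguments pickle should call it with on load. A standalone sketch of the same pattern (toy class, not tinygrad code):

import pickle

class Blob:
  def __init__(self, size, initial_value=None):
    self.size = size
    self.data = bytearray(initial_value) if initial_value is not None else bytearray(size)
  def __reduce__(self):
    # pickle will rebuild this as Blob(size, bytes(...)) -- same shape as Buffer.__reduce__ above
    return self.__class__, (self.size, bytes(self.data))

b = Blob(4, b"\x01\x02\x03\x04")
assert pickle.loads(pickle.dumps(b)).data == b.data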
@@ -187,20 +192,27 @@ class Compiler:
return lib
class CompiledASTRunner(JITRunner):
-def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
+def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None):
super().__init__()
if DEBUG >= 4: print(prg)
if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-self.name, self.display_name, self.prg, self.device, self.global_size, self.local_size, self.first_run = \
-to_function_name(name), name, prg, device, global_size, local_size, True
-assert self.device.compiler is not None, "compiler is reuired to make an AST kernel"
+self.name, self.display_name, self.prg, self.dname, self.global_size, self.local_size, self.first_run = \
+to_function_name(name), name, prg, dname, global_size, local_size, True
+assert self.device.compiler is not None, "compiler is required to make an AST kernel"
lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
self.vars: List[Variable] = [] if variables is None else variables
self.op_estimate, self.mem_estimate = op_estimate, mem_estimate
+@property
+def device(self): return Device[self.dname]
+def __reduce__(self):
+return self.__class__, (self.name, self.prg, self.dname, self.global_size, self.local_size,
+self.vars, self.op_estimate, self.mem_estimate, self.lib)
def launch_dims(self, var_vals):
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
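
The dname change above is the usual trick for making objects with unpicklable runtime state serializable: keep only the device name (a plain string) on the instance, resolve the live device through a property, and let __reduce__ hand pickle the constructor arguments. A self-contained sketch (hypothetical names, not tinygrad code):

import pickle

DEVICES = {"CPU": object()}   # stand-in for tinygrad's Device registry

class Runner:
  def __init__(self, dname:str):
    self.dname = dname                              # picklable string, not a runtime handle
  @property
  def device(self): return DEVICES[self.dname]      # looked up lazily, never serialized
  def __reduce__(self): return self.__class__, (self.dname,)

assert pickle.loads(pickle.dumps(Runner("CPU"))).dname == "CPU"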
@@ -218,7 +230,7 @@ class CompiledASTRunner(JITRunner):
if local_size: lra['local_size'] = local_size
et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
-lra=lra, device=self.device.dname, first_run=self.first_run)
+lra=lra, device=self.dname, first_run=self.first_run)
self.first_run = False
return et
@@ -238,7 +250,7 @@ class Compiled:
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
-ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
+ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret


@@ -122,6 +122,7 @@ class TinyJit(Generic[ReturnType]):
for p in get_parameters(self.ret): p.realize()
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
+del self.fxn
if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
print("WARNING: some input tensors not found")
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
@@ -162,7 +163,7 @@ class _CacheCollector:
self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
self.var_vals = var_vals if var_vals is not None else {}
-def add(self, prg, rawbufs, var_vals):
+def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]):
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"


@@ -36,7 +36,7 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz
factor = 1
if global_size is not None and max_global_size is not None:
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
-try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib)
+try: car = CompiledASTRunner(name, "", rdev.dname, global_size, local_size, variables=variables, precompiled=lib)
except AssertionError: return [math.inf] * cnt
tms = []
for _ in range(cnt):


@@ -1,30 +1,32 @@
from __future__ import annotations
import math
-from typing import Union, Optional, Any, Tuple, List, Dict, cast
+from typing import Union, Optional, Any, Tuple, List
from tinygrad.dtype import cast_scalar, dtypes, DType, Scalar
from tinygrad.helpers import prod, getenv, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op
from tinygrad.shape.symbolic import sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
-from weakref import ref, ReferenceType
+from weakref import ref, ReferenceType, WeakValueDictionary
-lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
+lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
if op is LoadOps.CONST: enable_cache = True
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-if enable_cache and (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference
+if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
-return LazyBuffer(device, st, dtype, op, arg, srcs, base=base, cache_key=cache_key if enable_cache else None)
+ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+if enable_cache: lazycache[cache_key] = ret
+return ret
class LazyBuffer:
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-base:Optional[LazyBuffer]=None, cache_key=None):
-self.device, self.st, self.dtype, self.shape, self.size, self.cache_key = device, st, dtype, st.shape, st.size, cache_key
+base:Optional[LazyBuffer]=None):
+self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
self._base: Optional[LazyBuffer] = None
if base is None:
# properties on base
@@ -37,9 +39,6 @@ class LazyBuffer:
# properties on view
assert base.base == base, "base must be a base itself"
self._base = base
-if cache_key is not None: lazycache[cache_key] = ref(self)
-def __del__(self): lazycache.pop(self.cache_key, None)
def __repr__(self) -> str:
return f"<LB {self.device} {self.shape} contig:{self.st.contiguous} {self.st if self.base != self else (self.op, self.realized)}>"