Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
support pickling tensors and others (#3787)
* test pickle tensors
* pickle unrealized tensor
* pickle jit, don't save Device in every CompiledASTRunner
* real test of pickle, move delete
test/external/external_hip_compiler_bug.py (vendored, 4 changed lines)
@@ -216,8 +216,8 @@ b2 = Buffer(dev, 9408, dtypes.float)
 print(hex(b0._buf.value), hex(b0._buf.value+1605632*4))
 print(hex(b1._buf.value))
 print(hex(b2._buf.value))
-#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [7, 1, 1], [8, 4, 1], precompiled=lib)
-prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", Device[dev], [49, 8, 2], [8, 4, 1], precompiled=lib)
+#prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [7, 1, 1], [8, 4, 1], precompiled=lib)
+prg = CompiledASTRunner("r_2_8_7_7_4_8_3_7_7_4_4_2_2", "", dev, [49, 8, 2], [8, 4, 1], precompiled=lib)
 print("compiled")
 prg([b0, b1, b2], {})
 print("ran")
test/external/speed_compare_cuda_ptx.py (vendored, 2 changed lines)
@@ -38,7 +38,7 @@ if __name__ == "__main__":
     lin.linearize()
     ptx_src = ptx.render(to_function_name(lin.name), lin.uops)
     try:
-      ptx_prg = CompiledASTRunner(lin.name, ptx_src, dev, lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
+      ptx_prg = CompiledASTRunner(lin.name, ptx_src, "CUDA", lin.global_size, lin.local_size, lin.uops.vars(), precompiled=ptx.compile(ptx_src))
     except RuntimeError:
       print("PTX FAIL")
       continue
@@ -21,7 +21,7 @@ def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
   int idx = get_global_id(0);
   c[idx] = atan2(a[idx], b[idx]);
 }"""
-  CompiledASTRunner("atan2_gpu", src, Device[ret.device], global_size=[ret.size]).exec([ret, a, b])
+  CompiledASTRunner("atan2_gpu", src, ret.device, global_size=[ret.size]).exec([ret, a, b])

 def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)

test/test_pickle.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+import unittest, pickle
+import numpy as np
+from tinygrad import Tensor, TinyJit
+
+class TestPickle(unittest.TestCase):
+  def test_pickle_realized_tensor(self):
+    t = Tensor.rand(10, 10).realize()
+    st = pickle.dumps(t)
+    t2:Tensor = pickle.loads(st)
+    np.testing.assert_equal(t.numpy(), t2.numpy())
+
+  def test_pickle_unrealized_tensor(self):
+    t = Tensor.ones(10, 10)
+    st = pickle.dumps(t)
+    t2:Tensor = pickle.loads(st)
+    np.testing.assert_equal(t.numpy(), t2.numpy())
+
+  def test_pickle_jit(self):
+    @TinyJit
+    def add(a, b): return a+b+1
+    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
+    #import dill
+    #with dill.detect.trace(): dill.dumps(add)
+    st = pickle.dumps(add)
+    add_fxn = pickle.loads(st)
+
+    x = Tensor.ones(10, 10).contiguous().realize()
+    y = Tensor.ones(10, 10).contiguous().realize()
+    print("post jit")
+    out = add_fxn(x, y)
+    np.testing.assert_equal(out.numpy(), 3)
+
+if __name__ == '__main__':
+  unittest.main()
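These tests rely on Python's standard pickling hooks, which the rest of this commit wires up for Buffer and CompiledASTRunner (see the device.py hunks below). A minimal sketch of the __reduce__ pattern being used; the Blob class here is illustrative, not tinygrad code:

import pickle
from typing import Optional

class Blob:
  """Illustrative stand-in for an object wrapping an opaque, unpicklable resource."""
  def __init__(self, size: int, data: Optional[bytes] = None):
    self.size = size
    self._buf = bytearray(data if data is not None else size)  # stand-in for device memory

  def __reduce__(self):
    # tell pickle to rebuild the object by calling the class again with plain,
    # picklable arguments (the raw bytes) instead of serializing the opaque handle
    return self.__class__, (self.size, bytes(self._buf))

b2 = pickle.loads(pickle.dumps(Blob(4, b"\x01\x02\x03\x04")))
assert bytes(b2._buf) == b"\x01\x02\x03\x04"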
@@ -13,7 +13,7 @@ from test.helpers import is_dtype_supported
 def _uops_to_prg(uops):
   src = Device[Device.DEFAULT].compiler.render("test", uops)
   has_local = Device[Device.DEFAULT].compiler.linearizer_opts.has_local
-  return CompiledASTRunner("test", src, Device[Device.DEFAULT], [1] if has_local else None, [1] if has_local else None)
+  return CompiledASTRunner("test", src, Device.DEFAULT, [1] if has_local else None, [1] if has_local else None)

 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(vin), arg))
@@ -75,7 +75,7 @@ class BufferOptions:
   signal: bool = False

 class Buffer:
-  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None):
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:Optional[BufferOptions]=None, initial_value:Optional[bytes]=None):
     assert isinstance(dtype, DType)
     if isinstance(dtype, ImageDType): options = BufferOptions(image=dtype) # TODO: image hack shouldn't be here. where should it be?
     self.device, self.size, self.dtype, self.d, self.options = device, size, dtype, Device[device], options
@@ -83,6 +83,11 @@ class Buffer:
     self._buf = opaque if opaque is not None else self.allocator.alloc(self.nbytes, options)
     # TODO: mem_used for all devices
     if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.nbytes
+    if initial_value is not None: self.copyin(memoryview(initial_value))
+  def __reduce__(self):
+    buf = bytearray(self.nbytes)
+    self.copyout(memoryview(buf))
+    return self.__class__, (self.device, self.size, self.dtype, None, self.options, buf)
   @property
   def nbytes(self): return self.size*self.dtype.itemsize
   def __del__(self):
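The new initial_value argument and __reduce__ together are what let a realized tensor survive pickle.dumps: at dump time the buffer's contents are copied out to host bytes, and at load time a fresh Buffer is constructed and the bytes are copied back in. A hedged usage sketch, assuming a device whose allocator supports copyin/copyout (e.g. the default device):

import pickle
from tinygrad import dtypes
from tinygrad.device import Buffer, Device

# build a 4-byte buffer from host data using the new initial_value argument
b = Buffer(Device.DEFAULT, 4, dtypes.uint8, initial_value=bytes([1, 2, 3, 4]))

# __reduce__ copies the bytes out; unpickling calls Buffer(...) again and copies them back in
b2 = pickle.loads(pickle.dumps(b))
out = bytearray(b2.nbytes)
b2.copyout(memoryview(out))
assert bytes(out) == bytes([1, 2, 3, 4])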
@@ -187,20 +192,27 @@ class Compiler:
     return lib

 class CompiledASTRunner(JITRunner):
-  def __init__(self, name:str, prg:str, device:Compiled, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
+  def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
                variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None):
     super().__init__()
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-    self.name, self.display_name, self.prg, self.device, self.global_size, self.local_size, self.first_run = \
-      to_function_name(name), name, prg, device, global_size, local_size, True
-    assert self.device.compiler is not None, "compiler is reuired to make an AST kernel"
+    self.name, self.display_name, self.prg, self.dname, self.global_size, self.local_size, self.first_run = \
+      to_function_name(name), name, prg, dname, global_size, local_size, True
+    assert self.device.compiler is not None, "compiler is required to make an AST kernel"
     lib:bytes = precompiled if precompiled is not None else self.device.compiler.compile_cached(prg)
     self.lib, self.clprg = lib, self.device.runtime(self.name, lib)
     self.vars: List[Variable] = [] if variables is None else variables
     self.op_estimate, self.mem_estimate = op_estimate, mem_estimate

+  @property
+  def device(self): return Device[self.dname]
+
+  def __reduce__(self):
+    return self.__class__, (self.name, self.prg, self.dname, self.global_size, self.local_size,
+                            self.vars, self.op_estimate, self.mem_estimate, self.lib)
+
   def launch_dims(self, var_vals):
     global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
     local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
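The signature change from device:Compiled to dname:str is the core of making the runner picklable: a live Compiled object owns driver handles that cannot be serialized, while the device name is a plain string, and the new device property re-resolves it through the Device registry whenever it is needed, including after unpickling. A small sketch of this "store the name, look the object up lazily" pattern, using illustrative names rather than tinygrad's:

import pickle

DEVICES = {"FAKE": object()}   # illustrative registry of live (unpicklable) device objects

class Runner:
  def __init__(self, name: str, dname: str, lib: bytes):
    # only plain, picklable state is stored on the instance
    self.name, self.dname, self.lib = name, dname, lib

  @property
  def device(self): return DEVICES[self.dname]   # re-resolved on every access, also after unpickling

  def __reduce__(self):
    return self.__class__, (self.name, self.dname, self.lib)

r2 = pickle.loads(pickle.dumps(Runner("kernel", "FAKE", b"\x00")))
assert r2.device is DEVICES["FAKE"]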
@@ -218,7 +230,7 @@ class CompiledASTRunner(JITRunner):
     if local_size: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
     if do_update_stats: update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit,
-                                     lra=lra, device=self.device.dname, first_run=self.first_run)
+                                     lra=lra, device=self.dname, first_run=self.first_run)
     self.first_run = False
     return et
@@ -238,7 +250,7 @@ class Compiled:
     ops, mem = k.uops.flops_mem()
     run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
-    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
+    ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self.dname, k.global_size, k.local_size,
                             k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
     return ret
@@ -122,6 +122,7 @@ class TinyJit(Generic[ReturnType]):
       for p in get_parameters(self.ret): p.realize()
       self.jit_cache = CacheCollector.finish()
       assert len(self.jit_cache) != 0, "didn't JIT anything!"
+      del self.fxn
       if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers):
         print("WARNING: some input tensors not found")
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
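The added del self.fxn drops the JIT's reference to the original Python function once the kernel cache has been captured. The cache is all that is needed for replay, and standard pickle cannot serialize arbitrary local functions or closures (the commented-out dill lines in the new test hint at that investigation), so keeping the reference would make pickle.dumps(add) fail. A quick hedged illustration of the underlying limitation:

import pickle

def make_adder():
  def add(a, b): return a + b + 1   # defined locally, like the function under @TinyJit in the test
  return add

try:
  pickle.dumps(make_adder())
except Exception as e:              # CPython raises AttributeError/PicklingError for local objects
  print("not picklable:", e)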
@@ -162,7 +163,7 @@ class _CacheCollector:
     self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
     self.var_vals = var_vals if var_vals is not None else {}

-  def add(self, prg, rawbufs, var_vals):
+  def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]):
     if self.cache is None: return
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
||||
@@ -36,7 +36,7 @@ def _time_program(variables:List[Variable], rdev:Compiled, lib:bytes, global_siz
|
||||
factor = 1
|
||||
if global_size is not None and max_global_size is not None:
|
||||
global_size, factor = _get_test_global_size(global_size, max_global_size, var_vals)
|
||||
try: car = CompiledASTRunner(name, "", rdev, global_size, local_size, variables=variables, precompiled=lib)
|
||||
try: car = CompiledASTRunner(name, "", rdev.dname, global_size, local_size, variables=variables, precompiled=lib)
|
||||
except AssertionError: return [math.inf] * cnt
|
||||
tms = []
|
||||
for _ in range(cnt):
|
||||
|
||||
@@ -1,30 +1,32 @@
 from __future__ import annotations
 import math
-from typing import Union, Optional, Any, Tuple, List, Dict, cast
+from typing import Union, Optional, Any, Tuple, List
 from tinygrad.dtype import cast_scalar, dtypes, DType, Scalar
 from tinygrad.helpers import prod, getenv, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op
 from tinygrad.shape.symbolic import sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
-from weakref import ref, ReferenceType
+from weakref import ref, ReferenceType, WeakValueDictionary

-lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
+lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
   if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
   if op is LoadOps.CONST: enable_cache = True

   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
-  if enable_cache and (rret := lazycache.get(cache_key, None)): return cast(LazyBuffer, rret()) # NOTE: this should always be a live reference
+  if enable_cache and (rret := lazycache.get(cache_key, None)): return rret

-  return LazyBuffer(device, st, dtype, op, arg, srcs, base=base, cache_key=cache_key if enable_cache else None)
+  ret = LazyBuffer(device, st, dtype, op, arg, srcs, base=base)
+  if enable_cache: lazycache[cache_key] = ret
+  return ret

 class LazyBuffer:
   def __init__(self, device:str, st:ShapeTracker, dtype:DType,
                op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
-               base:Optional[LazyBuffer]=None, cache_key=None):
-    self.device, self.st, self.dtype, self.shape, self.size, self.cache_key = device, st, dtype, st.shape, st.size, cache_key
+               base:Optional[LazyBuffer]=None):
+    self.device, self.st, self.dtype, self.shape, self.size = device, st, dtype, st.shape, st.size
     self._base: Optional[LazyBuffer] = None
     if base is None:
       # properties on base
@@ -37,9 +39,6 @@ class LazyBuffer:
       # properties on view
       assert base.base == base, "base must be a base itself"
       self._base = base
-      if cache_key is not None: lazycache[cache_key] = ref(self)
-
-  def __del__(self): lazycache.pop(self.cache_key, None)

   def __repr__(self) -> str:
     return f"<LB {self.device} {self.shape} contig:{self.st.contiguous} {self.st if self.base != self else (self.op, self.realized)}>"
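Swapping the hand-rolled Dict[Any, ReferenceType[LazyBuffer]] plus __del__ cleanup for a WeakValueDictionary keeps the same deduplication: a cache hit still returns the live LazyBuffer, but entries now disappear automatically when the last strong reference dies, so the cache_key field and the __del__ method are no longer needed. A standalone sketch of that behavior (on CPython, where refcounting frees objects immediately):

from weakref import WeakValueDictionary

class Node: pass

cache: "WeakValueDictionary[str, Node]" = WeakValueDictionary()

n = Node()
cache["key"] = n
assert cache.get("key") is n       # hit path: the live object itself is returned

del n                              # drop the last strong reference...
assert cache.get("key") is None    # ...and the entry is evicted with no manual __del__ bookkeeping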