From ebc94c9d6c622da8555fbfae79766c64a3a3c577 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 12 Apr 2024 21:54:36 -0700 Subject: [PATCH] rewrite the jit in the context of new schedule (#4162) * rewrite the jit in the context of new schedule * mypy better * fix placeholder * tests * all functionality should work * fix tests * no CacheCollector --- examples/benchmark_train_efficientnet.py | 66 ---------- test/external/external_test_opt.py | 89 +++++++------ test/test_linearizer.py | 10 +- tinygrad/device.py | 8 +- tinygrad/engine/jit.py | 158 ++++++++++------------- tinygrad/engine/realize.py | 10 +- tinygrad/lazy.py | 4 + 7 files changed, 134 insertions(+), 211 deletions(-) delete mode 100755 examples/benchmark_train_efficientnet.py diff --git a/examples/benchmark_train_efficientnet.py b/examples/benchmark_train_efficientnet.py deleted file mode 100755 index fb26e2893a..0000000000 --- a/examples/benchmark_train_efficientnet.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -import gc -import time -from tqdm import trange -from extra.models.efficientnet import EfficientNet -from tinygrad.nn.state import get_parameters -from tinygrad.nn import optim -from tinygrad import Tensor, GlobalCounters -from tinygrad.helpers import getenv -from tinygrad.engine.jit import CacheCollector - -def tensors_allocated(): - return sum(isinstance(x, Tensor) for x in gc.get_objects()) - -NUM = getenv("NUM", 2) -BS = getenv("BS", 8) -CNT = getenv("CNT", 10) -BACKWARD = getenv("BACKWARD", 0) -TRAINING = getenv("TRAINING", 1) -ADAM = getenv("ADAM", 0) -CLCACHE = getenv("CLCACHE", 0) - -if __name__ == "__main__": - print(f"NUM:{NUM} BS:{BS} CNT:{CNT}") - model = EfficientNet(NUM, classes=1000, has_se=False, track_running_stats=False) - parameters = get_parameters(model) - for p in parameters: p.realize() - if ADAM: optimizer = optim.Adam(parameters, lr=0.001) - else: optimizer = optim.SGD(parameters, lr=0.001) - - Tensor.training = TRAINING - Tensor.no_grad = not BACKWARD - for i in trange(CNT): - GlobalCounters.reset() - cpy = time.monotonic() - x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize() - y_train = Tensor.randn(BS, 1000, requires_grad=False).realize() - - # TODO: replace with TinyJit - if i < 3 or not CLCACHE: - st = time.monotonic() - out = model.forward(x_train) - loss = out.log_softmax().mul(y_train).mean() - if i == 2 and CLCACHE: CacheCollector.start() - if BACKWARD: - optimizer.zero_grad() - loss.backward() - optimizer.step() - mt = time.monotonic() - loss.realize() - for p in parameters: - p.realize() - et = time.monotonic() - else: - st = mt = time.monotonic() - for prg, args in cl_cache: prg(*args) - et = time.monotonic() - - if i == 2 and CLCACHE: - cl_cache = CacheCollector.finish() - - mem_used = GlobalCounters.mem_used - loss_cpu = loss.detach().numpy() - cl = time.monotonic() - - print(f"{(st-cpy)*1000.0:7.2f} ms cpy, {(cl-st)*1000.0:7.2f} ms run, {(mt-st)*1000.0:7.2f} ms build, {(et-mt)*1000.0:7.2f} ms realize, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {tensors_allocated():4d} tensors, {mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py index a9158a74ee..1107a8b25b 100644 --- a/test/external/external_test_opt.py +++ b/test/external/external_test_opt.py @@ -10,24 +10,27 @@ from tinygrad.helpers import getenv from tinygrad.nn import optim #from tinygrad.lazy import PUSH_PERMUTES PUSH_PERMUTES = 
False -from tinygrad.engine.jit import CacheCollector +from tinygrad.engine.realize import capturing class CLCache: def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {} + self.count = 0 + def add(self, ei): self.count += 1 def __enter__(self): if self.preclear: gc.collect() for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]: x.realize() GlobalCounters.reset() - CacheCollector.start(self.var_vals) + capturing.append(self) print("cache: entering") + return self def __exit__(self, type, value, traceback): - cache = CacheCollector.finish() - print(f"cache: exiting with size {len(cache)}", f"allowed {self.allowed}" if self.allowed is not None else "") + capturing.clear() + print(f"cache: exiting with size {self.count}", f"allowed {self.allowed}" if self.allowed is not None else "") if self.allowed is not None: - assert len(cache) <= self.allowed and (not self.strict or len(cache) == self.allowed), f"used too many kernels! {len(cache)} > {self.allowed}" + assert self.count <= self.allowed and (not self.strict or self.count == self.allowed), f"used too many kernels! {self.count} > {self.allowed}" from extra.models.convnext import ConvNeXt from extra.models.efficientnet import EfficientNet @@ -77,9 +80,9 @@ class TestInferenceMinKernels(unittest.TestCase): model = ViT(embed_dim=192, num_heads=3) for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np)) img = Tensor.randn(1, 3, 224, 224) - with CLCache(222): # NOTE: this is way too high + with CLCache(222) as cache: # NOTE: this is way too high out = model.forward(img) - assert len(CacheCollector.cache) == 0, "ViT prerealized?" + assert cache.count == 0, "ViT prerealized?" out.realize() @unittest.skip("llama is fp16 but CI does not have fp16") @@ -97,12 +100,12 @@ class TestOptBinOp(unittest.TestCase): def _test_no_binop_rerun(self, f1, f2=None, allowed=1): a = Tensor.randn(16, 16) b = Tensor.randn(16, 16) - with CLCache(): + with CLCache() as cache: c = f1(a, b) if f2 is not None: d = f2(a, b) c.realize() if f2 is not None: d.realize() - assert len(CacheCollector.cache) == allowed, "binop was rerun!" + assert cache.count == allowed, "binop was rerun!" 
if f2 is not None: np.testing.assert_allclose(c.numpy().ravel(), d.numpy().ravel(), rtol=1e-3, atol=1e-5) def test_no_binop_rerun(self): return self._test_no_binop_rerun(lambda a,b: a*b, lambda a,b: (a*b).reshape(16, 16, 1)) @@ -125,22 +128,22 @@ class TestOptReduceLoop(unittest.TestCase): def test_loop_left(self): a = Tensor.randn(16, 16) b = Tensor.randn(16, 16) - with CLCache(): + with CLCache() as cache: t = a.sum(0) b = t.reshape(16,1).expand(16,16).sum(0) c = (t+b) c.realize() - assert len(CacheCollector.cache) == 2, "loop left fusion broken" + assert cache.count == 2, "loop left fusion broken" def test_loop_right(self): a = Tensor.randn(16, 16) b = Tensor.randn(16, 16) - with CLCache(): + with CLCache() as cache: t = a.sum(0) b = t.reshape(16,1).expand(16,16).sum(0) c = (b+t) c.realize() - assert len(CacheCollector.cache) == 2, "loop right fusion broken" + assert cache.count == 2, "loop right fusion broken" @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOptWChild(unittest.TestCase): @@ -148,12 +151,12 @@ class TestOptWChild(unittest.TestCase): def test_unrealized_child(self): a = Tensor.randn(16, 16) b = Tensor.randn(16, 16) - with CLCache(): + with CLCache() as cache: c = (a*b).sum() d = c+1 e = c+2 # noqa: F841 d.realize() - assert len(CacheCollector.cache) == 2, "don't fuse if you have children" + assert cache.count == 2, "don't fuse if you have children" @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented") class TestOpt(unittest.TestCase): @@ -168,34 +171,34 @@ class TestOpt(unittest.TestCase): def test_fold_reduce_elementwise(self): img = Tensor.ones(32).contiguous() addme = Tensor.ones(1) - with CLCache(): + with CLCache() as cache: ret = img.sum() + addme ret.realize() - assert len(CacheCollector.cache) == 1, "optimizer didn't fold reduce/elementwise" + assert cache.count == 1, "optimizer didn't fold reduce/elementwise" assert ret.item() == 33 def test_fold_batchnorm(self): with Tensor.train(): img = Tensor.ones(1,32,4,4).contiguous() bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): + with CLCache() as cache: img_bn = bn(img).realize() print(img_bn) - assert len(CacheCollector.cache) == 3, f"optimizer didn't fold batchnorm, got {len(CacheCollector.cache)}" + assert cache.count == 3, f"optimizer didn't fold batchnorm, got {cache.count}" def test_fold_conv_sgd(self): with Tensor.train(): img = Tensor.ones(2,3,4,4) c1 = nn.Conv2d(3,32,3) opt = optim.SGD(get_parameters(c1)) - with CLCache(): + with CLCache() as cache: opt.zero_grad() c1(img).relu().sum().backward() opt.step() # TODO: this should be 4, but the sum output child stays around # with pushing_permutes it can be 3 # TODO: broken with optim fixes - assert len(CacheCollector.cache) in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {len(CacheCollector.cache)}" + assert cache.count in [4,5,6], f"optimizer didn't fold conv-backward SGD, got {cache.count}" def test_fold_2convs_sgd(self): with Tensor.train(): @@ -239,74 +242,74 @@ class TestOpt(unittest.TestCase): bn = nn.BatchNorm2d(32, track_running_stats=False) # precache the bn bn(c1(img)).relu().realize() - with CLCache(): + with CLCache() as cache: bn(c1(img)).relu().realize() - assert len(CacheCollector.cache) == 1, f"optimizer didn't fold conv-batchnorm at test time, got {len(CacheCollector.cache)}" + assert cache.count == 1, f"optimizer didn't fold conv-batchnorm at test time, got {cache.count}" def test_fold_conv_batchnorm(self): with Tensor.train(): img = Tensor.ones(1,3,8,8) c1 = nn.Conv2d(3,32,3) 
bn = nn.BatchNorm2d(32, track_running_stats=False) - with CLCache(): + with CLCache() as cache: img_conv = bn(c1(img)).relu().realize() print(img_conv) - assert len(CacheCollector.cache) == 4, f"optimizer didn't fold conv-batchnorm, got {len(CacheCollector.cache)}" + assert cache.count == 4, f"optimizer didn't fold conv-batchnorm, got {cache.count}" def test_fold_conv_elu(self): img = Tensor.ones(1,4,8,8) c1 = nn.Conv2d(4, 4, kernel_size=3) c2 = nn.Conv2d(4, 4, kernel_size=3) - with CLCache(): + with CLCache() as cache: img_conv = img.sequential([c1, Tensor.elu, c2, Tensor.elu]).realize() print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/elu" + assert cache.count == 2, "optimizer didn't fold conv/elu" def test_fold_conv_relu(self): img = Tensor.ones(1,4,8,8) c1 = nn.Conv2d(4, 4, kernel_size=3) c2 = nn.Conv2d(4, 4, kernel_size=3) - with CLCache(): + with CLCache() as cache: img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" + assert cache.count == 2, "optimizer didn't fold conv/relu" def test_fold_conv_relu_nobias(self): img = Tensor.ones(1,4,8,8) c1 = nn.Conv2d(4, 4, kernel_size=3, bias=False) c2 = nn.Conv2d(4, 4, kernel_size=3, bias=False) - with CLCache(): + with CLCache() as cache: img_conv = img.sequential([c1, Tensor.relu, c2, Tensor.relu]).realize() print(img_conv) - assert len(CacheCollector.cache) == 2, "optimizer didn't fold conv/relu" + assert cache.count == 2, "optimizer didn't fold conv/relu" def test_permute_was_pushed(self): a = Tensor.randn(16, 16, 16) - with CLCache(2): + with CLCache(2) as cache: c = a.sum(2) d = c.permute(1,0).contiguous() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" def test_permute_was_pushed_through_contract_reshape(self): a = Tensor.randn(4, 4, 4, 4, 4) - with CLCache(2): + with CLCache(2) as cache: c = a.sum(-1) d = c.reshape(16,16).permute(1,0).contiguous() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,16).transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" def test_permute_was_pushed_through_contractw1s_reshape(self): a = Tensor.randn(4, 4, 4, 4, 4) - with CLCache(2): + with CLCache(2) as cache: c = a.sum(-1) d = c.reshape(16,1,16).permute(2,1,0).contiguous() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(a.numpy().sum(-1).reshape(16,1,16).transpose(2,1,0), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" @@ -315,35 +318,35 @@ class TestOpt(unittest.TestCase): @unittest.skipIf(not PUSH_PERMUTES, "this test requires PUSH_PERMUTES") def test_permute_was_pushed_through_expand_reshape(self): a = Tensor.randn(16, 16, 16) - with CLCache(): + with CLCache() as cache: c = a.sum(2) d = c.reshape(4,4,4,4).permute(2,3,0,1).contiguous() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(a.numpy().sum(2).transpose(1,0).reshape(4,4,4,4), d.numpy(), rtol=1e-3, atol=1e-5) if PUSH_PERMUTES: assert cache_len == 1, "permute wasn't pushed!" 
@unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") def test_no_reduceop_rerun(self): a = Tensor.randn(16, 16, 16) - with CLCache(): + with CLCache() as cache: c = a.sum(2) d = a.sum(2).permute(1,0) c.realize() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(c.numpy().transpose(1,0), d.numpy(), rtol=1e-3, atol=1e-5) assert cache_len == 1, "reduceop was rerun!" @unittest.skipIf(PUSH_PERMUTES, "this test is broken with PUSH_PERMUTES") def test_no_reduceop_rerun_alt(self): a = Tensor.randn(16, 16, 16) - with CLCache(): + with CLCache() as cache: c = a.sum(2).permute(1,0) d = a.sum(2) c.realize() d.realize() - cache_len = len(CacheCollector.cache) + cache_len = cache.count np.testing.assert_allclose(c.numpy(), d.numpy().transpose(1,0), rtol=1e-3, atol=1e-5) assert cache_len == 1, "reduceop was rerun!" diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 89a670c3dd..c040dceaa5 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -9,9 +9,8 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node from tinygrad.tensor import Tensor -from tinygrad.engine.jit import CacheCollector from tinygrad.engine.schedule import create_schedule -from tinygrad.engine.realize import run_schedule +from tinygrad.engine.realize import run_schedule, lower_schedule from tinygrad.helpers import prod, Context, getenv from tinygrad.dtype import DType, dtypes from tinygrad.codegen.uops import UOpGraph @@ -20,9 +19,10 @@ class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): a, b = Tensor.randn(4), Tensor.randn(4) np_a, np_b = a.numpy(), b.numpy() - CacheCollector.start() - c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))).realize() - rawbufs = CacheCollector.finish()[0].rawbufs + c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) + lowered = list(lower_schedule(create_schedule([c.lazydata]))) + for ei in lowered: ei.run() + rawbufs = lowered[-1].rawbufs assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized} np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) diff --git a/tinygrad/device.py b/tinygrad/device.py index 01508a6f2c..909c9b9c57 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,7 +3,7 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, NamedTuple import importlib, inspect, functools, pathlib, time, ctypes, os from tinygrad.helpers import ansilen, prod, getenv, colored, all_int, to_function_name, from_mv, flat_mv, diskcache_get, diskcache_put -from tinygrad.helpers import DEBUG, CACHECOLLECTING, BEAM, NOOPT, GlobalCounters +from tinygrad.helpers import DEBUG, BEAM, NOOPT, GlobalCounters from tinygrad.shape.symbolic import Variable, sym_infer, sint from tinygrad.ops import LazyOp, get_lazyop_info from tinygrad.buffer import Buffer, BufferOptions @@ -44,11 +44,7 @@ class Runner: self.op_estimate:sint = 0 self.mem_estimate:sint = 0 def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: - var_vals = var_vals if var_vals is not None else {} - from tinygrad.engine.jit import CacheCollector - et = self(rawbufs, var_vals) - if CACHECOLLECTING: CacheCollector.add(self, 
rawbufs, var_vals) - return et + return self(rawbufs, {} if var_vals is None else var_vals) def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: raise NotImplementedError("override this") diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 126e01dc3e..5016b93ee3 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -1,33 +1,28 @@ from __future__ import annotations -from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic +from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional import functools, itertools, operator -from tinygrad.nn.state import get_parameters -from tinygrad.dtype import DType -from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException -from tinygrad.device import Compiled, Runner, CompiledRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device +from dataclasses import dataclass from tinygrad.tensor import Tensor from tinygrad.lazy import LazyBuffer -from tinygrad.features.multi import MultiLazyBuffer +from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, all_int, GraphException +from tinygrad.device import Buffer, Runner, CompiledRunner, BufferXfer, Compiled, MultiDeviceJITGraph, Device +from tinygrad.dtype import DType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, sint -from tinygrad.engine.realize import ExecItem +from tinygrad.engine.realize import ExecItem, capturing +from tinygrad.nn.state import get_parameters from weakref import ref, WeakKeyDictionary +# TODO: these graph functions probably shouldn't exist here + def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]: return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0), \ functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledRunner)], 0) -def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]: - input_replace: Dict[Tuple[int, int], int] = {} - for j,ji in enumerate(jit_cache): - for i,a in enumerate(ji.rawbufs): - if a in input_rawbuffers: - input_replace[(j,i)] = input_rawbuffers.index(a) - return input_replace def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]: - return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501 + return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and \ + ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]: return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledRunner) and ji.prg.vars] - def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]: # Split JIT cache into batches for faster graph execution. # This allows the accelerator to run some batches while subsequent graphs are still being updated. 
@@ -68,7 +63,38 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer] if len(current_batch) > 0: flush_batch() return graphed_jit_cache -# *** JIT *** +def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]: + input_replace: Dict[Tuple[int, int], int] = {} + for j,ji in enumerate(jit_cache): + for i,a in enumerate(ji.rawbufs): + if a in input_rawbuffers: + input_replace[(j,i)] = input_rawbuffers.index(a) + return input_replace + +class PlaceHolder: + placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary() + def __init__(self, buf:Buffer): + self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options + def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options) + def __hash__(self): return hash(self.to_tuple()) + def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple() + @staticmethod + def create_if_needed(buf:Buffer) -> Union[PlaceHolder, Buffer]: + if found:=PlaceHolder.placeholders.get(buf, None): return found + if hasattr(buf, '_buf'): return buf + PlaceHolder.placeholders[buf] = ret = PlaceHolder(buf.ensure_allocated()) # TODO: do I need to allocate here? + return ret + + def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer: + ret = self.ref() + if ret: return ret + if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate() + return buffer_cache[self] + +@dataclass(frozen=True) +class WeakExecItem: + prg: Runner + rawbufs: List[Union[PlaceHolder, Buffer]] ReturnType = TypeVar('ReturnType') class TinyJit(Generic[ReturnType]): @@ -76,62 +102,56 @@ class TinyJit(Generic[ReturnType]): self.fxn = fxn self.reset() + def add(self, ei:ExecItem): + self._cc.append(WeakExecItem(ei.prg, [PlaceHolder.create_if_needed(buf) for buf in ei.rawbufs if buf is not None])) + def reset(self): + self._cc: List[WeakExecItem] = [] self.jit_cache: List[ExecItem] = [] self.input_replace: Dict[Tuple[int, int], int] = {} self.cnt: int = 0 - self.ret: Optional[ReturnType] = None - self.expected_vals: Optional[Tuple[Variable, ...]] = None - self.expected_name_sts_dtype_device: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType, Union[str, Tuple[str, ...]]], ...]] = None - # add support for instance methods - def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) + def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods def __call__(self, *args, **kwargs) -> ReturnType: - # all inputs (except const) are realized - input_tensors: Dict[Union[int, str], Tensor] = { cast(Union[int, str], k):v for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor } # noqa: E501 - Tensor.corealize(input_tensors.values()) - input_lbs: Dict[Union[int, str], Union[LazyBuffer, MultiLazyBuffer]] = {k:v.lazydata for k,v in input_tensors.items()} - expected_name_sts_dtype_device = tuple([(k, v.st.unbind()[0] if isinstance(v, LazyBuffer) else ShapeTracker.from_shape(v.shape), v.dtype, v.device) for k,v in input_lbs.items()]) #noqa: E501 - - # get rawbuffers - lbs: List[LazyBuffer] = [v for v in input_lbs.values() if isinstance(v, LazyBuffer)] + \ - flatten([mlb.lbs for mlb in input_lbs.values() if isinstance(mlb, MultiLazyBuffer)]) + input_tensors: List[Tuple[Union[int, str], Tensor]] = 
\ + [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor] + Tensor.corealize([x[1] for x in input_tensors]) + lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors]) + expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs] input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None] assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT" + var_vals: Dict[Variable, int] = merge_dicts([x[1] for x in expected_sts_var_dtype_device] + \ + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) - # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global - var_vals: Dict[Variable, int] = merge_dicts([arg.st.var_vals for arg in lbs] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) # noqa: E501 - expected_vals = tuple(var_vals.keys()) - + expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device] if self.cnt >= 2: # jit exec - assert self.expected_vals == expected_vals and self.expected_name_sts_dtype_device is not None, "missing/mismatch of var_vals" - assert all(x[0] == y[0] and x[1].views == y[1].views and x[2] == y[2] and x[3] == y[3] - for x,y in zip(self.expected_name_sts_dtype_device, expected_name_sts_dtype_device)), \ - f"mismatch of input tensors, expected {self.expected_name_sts_dtype_device} got {expected_name_sts_dtype_device}" + assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT" for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] if DEBUG >= 1: print(f"jit execs {len(self.jit_cache)} kernels") - for ji in self.jit_cache: ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG>=2, jit=True) + for ei in self.jit_cache: ei.run(var_vals, jit=True) elif self.cnt == 1: # jit capture - self.expected_vals, self.expected_name_sts_dtype_device = expected_vals, expected_name_sts_dtype_device - CacheCollector.start(var_vals) - with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)): + self.expected_names: List[Union[int, str]] = expected_names + self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs + with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)): + capturing.append(self) self.ret = self.fxn(*args, **kwargs) Tensor.corealize(get_parameters(self.ret)) - self.jit_cache = CacheCollector.finish() - assert len(self.jit_cache) != 0, "didn't JIT anything!" - # TODO: reset doesn't work if we delete this - #del self.fxn - if DEBUG >= 1 and len(set(get_input_replace(self.jit_cache, input_rawbuffers).values())) != len(input_rawbuffers): - print("WARNING: some input tensors not found") + capturing.clear() + assert len(self._cc), "didn't JIT anything!" + buffer_cache: Dict[PlaceHolder, Buffer] = {} + self.jit_cache = \ + [ExecItem(ei.prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in ei.rawbufs]) for ei in self._cc] + del self._cc if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs") # Condense the items into a graph executor. 
if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals) self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers) + if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found") elif self.cnt == 0: # jit ignore self.ret = self.fxn(*args, **kwargs) @@ -141,42 +161,4 @@ class TinyJit(Generic[ReturnType]): for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None self.cnt += 1 - return cast(ReturnType, self.ret) - -class PlaceHolder: - def __init__(self, buf:Buffer): - self.size, self.dtype, self.device, self.ref, self.bufid, self.options = buf.size, buf.dtype, buf.device, ref(buf), id(buf._buf), buf.options - def to_tuple(self): return (self.size, self.dtype, self.device, self.bufid, self.options) - def __hash__(self): return hash(self.to_tuple()) - def __eq__(self, x): return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple() - def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer: - ret = self.ref() - if ret: return ret - if self not in buffer_cache: buffer_cache[self] = Buffer(self.device, self.size, self.dtype, options=self.options).allocate() - return buffer_cache[self] - -class _CacheCollector: - def __init__(self): - self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None - - def start(self, var_vals:Optional[Dict[Variable, int]]=None): - self.cache = [] - self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary() - self.var_vals = var_vals if var_vals is not None else {} - - def add(self, prg, rawbufs:List[Buffer], var_vals:Dict[Variable, int]): - if self.cache is None: return - for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}" - - # Buffer optimization is allowed only for kernel operations. Avoids for copies (prevents parallelism) and syncs (incorrect buffer reuse). 
- if isinstance(prg, CompiledRunner): - for i in range(prg.outcount): self.placeholders[rawbufs[i]] = PlaceHolder(rawbufs[i]) - - self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs])) - - def finish(self) -> List[ExecItem]: - if self.cache is None: return [] - buffer_cache: Dict[PlaceHolder, Buffer] = {} - saved_cache, self.cache = self.cache, None - return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache] -CacheCollector = _CacheCollector() + return self.ret diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 0d3ea40ef1..dcbc1855fb 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -10,8 +10,8 @@ from tinygrad.shape.symbolic import Variable class ExecItem: prg: Runner rawbufs: List[Optional[Buffer]] - def run(self, var_vals:Optional[Dict[Variable, int]]=None): - self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}) + def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False): + self.prg([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {}, wait=wait, jit=jit) class CustomOp(Runner): def __init__(self, fxn): @@ -38,5 +38,9 @@ def lower_schedule_item(si:ScheduleItem) -> Runner: def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]: while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs)) +capturing: List = [] # put classes with an add method in here + def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None): - for ei in lower_schedule(schedule): ei.run(var_vals) + for ei in lower_schedule(schedule): + if len(capturing): capturing[0].add(ei) + ei.run(var_vals) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index b31a8dca1f..4466203348 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -52,6 +52,10 @@ class LazyBuffer: @property def base(self) -> LazyBuffer: return self._base if self._base is not None else self + # same API as multi + @property + def lbs(self) -> List[LazyBuffer]: return [self] + @staticmethod def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Tuple[LazyBuffer, ...]=(), enable_cache=False) -> LazyBuffer: assert isinstance(src, tuple)
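
The patch replaces the CacheCollector singleton with the module-level `capturing` list in tinygrad/engine/realize.py: run_schedule hands every ExecItem to capturing[0].add(ei) right before running it, so any object that exposes an add() method can observe kernel launches. The rewritten CLCache helper in external_test_opt.py uses exactly this hook; below is a minimal standalone sketch of the same idea (not part of the patch; KernelCounter is a made-up name):

  from tinygrad import Tensor
  from tinygrad.engine.realize import capturing

  class KernelCounter:
    # run_schedule calls add(ei) for every ExecItem while an object sits in `capturing`
    def __init__(self): self.count = 0
    def add(self, ei): self.count += 1
    def __enter__(self):
      capturing.append(self)
      return self
    def __exit__(self, *exc): capturing.clear()

  with KernelCounter() as kc:
    (Tensor.ones(16, 16).contiguous().sum() + 1).realize()
  print("kernels captured:", kc.count)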
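
Outside code sees no API change from the TinyJit rewrite: call 0 runs eagerly, call 1 captures WeakExecItems through the same `capturing` hook (materializing PlaceHolders into Buffers once the function returns), and calls 2+ replay the cached ExecItems with only the input rawbuffers swapped in via input_replace. For reference, the usual usage pattern, as a sketch:

  from tinygrad import Tensor
  from tinygrad.engine.jit import TinyJit

  @TinyJit
  def step(x: Tensor) -> Tensor:
    return (x @ x.T).relu().sum(axis=1).realize()

  for i in range(4):
    out = step(Tensor.randn(8, 8))   # i==0 eager, i==1 capture, i>=2 replay cached kernels
    print(i, out.numpy().shape)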
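
During capture, TinyJit.add() passes every captured Buffer through PlaceHolder.create_if_needed: a buffer that already owns memory (an input or a weight) is kept as a real Buffer, while a not-yet-allocated intermediate is interned as a PlaceHolder in a class-level WeakKeyDictionary. When capture ends, alloc_if_needed reuses the original buffer if its weakref is still alive, otherwise it allocates one replacement per distinct (size, dtype, device, bufid, options) key. A small illustration of the interning rule, not from the patch (assumes the CLANG backend is available; any device name works):

  from tinygrad.dtype import dtypes
  from tinygrad.buffer import Buffer
  from tinygrad.engine.jit import PlaceHolder

  real = Buffer("CLANG", 16, dtypes.float32).allocate()  # owns memory -> stays a Buffer
  lazy = Buffer("CLANG", 16, dtypes.float32)             # unallocated -> becomes a PlaceHolder
  print(PlaceHolder.create_if_needed(real) is real)                              # True
  ph = PlaceHolder.create_if_needed(lazy)
  print(isinstance(ph, PlaceHolder), PlaceHolder.create_if_needed(lazy) is ph)   # True, True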
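
test_linearizer's test_arg_dedup now drives the schedule directly instead of reading CacheCollector.cache: lower_schedule yields one ExecItem per ScheduleItem, ei.run() executes it, and ei.rawbufs holds the output buffer followed by the (deduplicated) inputs. The same pattern works anywhere a test needs the concrete buffers behind a tensor's kernels, e.g.:

  from tinygrad import Tensor
  from tinygrad.engine.schedule import create_schedule
  from tinygrad.engine.realize import lower_schedule

  a = Tensor.randn(4).realize()
  c = a.shrink(((0, 2),)) - a.shrink(((2, 4),))
  eis = list(lower_schedule(create_schedule([c.lazydata])))
  for ei in eis: ei.run()
  print([b.size for b in eis[-1].rawbufs])  # output buffer first, then the deduped input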
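
The small lazy.py change gives LazyBuffer an lbs property returning [self], matching MultiLazyBuffer, so TinyJit.__call__ can flatten single- and multi-device inputs with one expression instead of branching on the type. A trivial sketch of that uniformity (not from the patch):

  from tinygrad import Tensor
  from tinygrad.helpers import flatten

  x = Tensor.randn(4, 4)
  # works whether x.lazydata is a LazyBuffer (lbs == [self]) or a MultiLazyBuffer
  lbs = flatten([t.lazydata.lbs for t in [x]])
  print(len(lbs))  # 1 for a single-device tensor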