@@ -3,7 +3,7 @@ from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar,
 import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
-from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten
+from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
 from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
@@ -32,7 +32,45 @@ def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int
 def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]

-class GraphException(Exception): pass
+def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
+  # Split JIT cache into batches for faster graph execution.
+  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
+  graphed_jit_cache: List[JitItem] = []
+  current_batch: List[JitItem] = []
+  current_device: Union[Compiled, None] = None
+
+  # Flush the current batch.
+  def flush():
+    nonlocal current_batch, current_device
+    assert current_device is not None
+    try:
+      graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
+      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+    except GraphException as e:
+      graphed_jit_cache.extend(current_batch)
+      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+    current_batch = []
+    current_device = None
+
+  for i,ji in enumerate(jit_cache):
+    # If the jit item can potentially be graphed, put it in a batch.
+    can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
+    if can_be_graphed:
+      assert isinstance(ji.prg, CompiledASTRunner)
+      # If the device changed we flush the batch early and append this item for the next batch.
+      if current_device is not None and ji.prg.device != current_device: flush()
+      current_device = ji.prg.device
+      current_batch.append(ji)
+
+    # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
+    # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
+    if len(current_batch) > 0 and (i==len(jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush()
+
+    # If the jit item cannot be graphed, put it right into the final cache after the flush.
+    if not can_be_graphed: graphed_jit_cache.append(ji)
+  return graphed_jit_cache

 # *** JIT ***

 ReturnType = TypeVar('ReturnType')
 class TinyJit(Generic[ReturnType]):
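
To make the batching policy of apply_graph_to_jit easier to follow, here is a minimal standalone sketch of the same flush rules acting on plain tuples. This is illustrative only, not tinygrad code: split_into_batches and the (device, can_be_graphed) inputs are made-up stand-ins for JitItems, and a real flush would try to build a device graph for the batch.

# illustrative sketch of the flush policy above, not tinygrad code
from typing import List, Optional, Tuple

def split_into_batches(items: List[Tuple[str, bool]], max_batch: int = 64) -> List[List[str]]:
  batches: List[List[str]] = []         # stands in for graphed_jit_cache
  current: List[str] = []               # stands in for current_batch
  current_device: Optional[str] = None

  def flush():
    # a real flush would graph the batch and fall back on GraphException; here we just close it
    nonlocal current, current_device
    if current: batches.append(current)
    current, current_device = [], None

  for i, (device, can_be_graphed) in enumerate(items):
    if can_be_graphed:
      # a device change ends the current batch early
      if current_device is not None and device != current_device: flush()
      current_device = device
      current.append(device)
    # flush when this is the last item, the batch is full, or an ungraphable item interrupts it
    if current and (i == len(items)-1 or len(current) >= max_batch or not can_be_graphed): flush()
    if not can_be_graphed: batches.append([f"{device} (not graphed)"])
  return batches

print(split_into_batches([("GPU", True), ("GPU", True), ("CPU", False), ("GPU", True)]))
# -> [['GPU', 'GPU'], ['CPU (not graphed)'], ['GPU']]
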
@@ -87,44 +125,7 @@ class TinyJit(Generic[ReturnType]):
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")

       # Condense the items into a graph executor.
-      if getenv("JIT") != 2:
-        # Split JIT cache into batches for faster graph execution.
-        # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-        graphed_jit_cache: List[JitItem] = []
-        current_batch: List[JitItem] = []
-        current_device: Union[Compiled, None] = None
-
-        # Flush the current batch.
-        def flush():
-          nonlocal current_batch, current_device
-          assert current_device is not None
-          try:
-            graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
-            if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
-          except GraphException as e:
-            graphed_jit_cache.extend(current_batch)
-            if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
-          current_batch = []
-          current_device = None
-
-        for i,ji in enumerate(self.jit_cache):
-          # If the jit item can potentially be graphed, put it in a batch.
-          can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
-          if can_be_graphed:
-            assert isinstance(ji.prg, CompiledASTRunner)
-            # If the device changed we flush the batch early and append this item for the next batch.
-            if current_device is not None and ji.prg.device != current_device: flush()
-            current_device = ji.prg.device
-            current_batch.append(ji)
-
-          # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
-          # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
-          if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush() # noqa: E501
-
-          # If the jit item cannot be graphed, put it right into the final cache after the flush.
-          if not can_be_graphed: graphed_jit_cache.append(ji)
-
-        self.jit_cache = graphed_jit_cache
+      if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)

       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
     elif self.cnt == 0:
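
For context, this is roughly how the changed path gets exercised; a sketch only, since the import location of TinyJit and the exact capture/replay call counts vary between tinygrad versions. JIT_BATCH_SIZE (default 64 above) caps how many kernels go into one graph, and JIT=2 skips apply_graph_to_jit entirely.

# usage sketch; the import path is an assumption (tinygrad.jit vs tinygrad.features.jit by version)
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()  # jitted functions return realized tensors

for _ in range(4):
  # early calls run and capture the kernels into jit_cache; later calls replay the
  # (graph-batched) cache with only the input buffers and var_vals swapped in
  out = step(Tensor.rand(64, 64))
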