diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 1974e97917..7b54390e5a 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -65,6 +65,8 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
 def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
 
+class GraphException(Exception): pass
+
 class Context(contextlib.ContextDecorator):
   stack: ClassVar[List[dict[str, int]]] = [{}]
   def __init__(self, **kwargs): self.kwargs = kwargs
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index b428c08167..a94bbe6694 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -3,7 +3,7 @@ from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar,
 import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
-from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten
+from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
 from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
@@ -32,7 +32,45 @@ def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int
 def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
 
-class GraphException(Exception): pass
+def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
+  # Split JIT cache into batches for faster graph execution.
+  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
+  graphed_jit_cache: List[JitItem] = []
+  current_batch: List[JitItem] = []
+  current_device: Union[Compiled, None] = None
+
+  # Flush the current batch.
+  def flush():
+    nonlocal current_batch, current_device
+    assert current_device is not None
+    try:
+      graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
+      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+    except GraphException as e:
+      graphed_jit_cache.extend(current_batch)
+      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+    current_batch = []
+    current_device = None
+
+  for i,ji in enumerate(jit_cache):
+    # If the jit item can potentially be graphed, put it in a batch.
+    can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
+    if can_be_graphed:
+      assert isinstance(ji.prg, CompiledASTRunner)
+      # If the device changed we flush the batch early and append this item for the next batch.
+      if current_device is not None and ji.prg.device != current_device: flush()
+      current_device = ji.prg.device
+      current_batch.append(ji)
+
+    # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
+    # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
+    if len(current_batch) > 0 and (i==len(jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush()
+
+    # If the jit item cannot be graphed, put it right into the final cache after the flush.
+    if not can_be_graphed: graphed_jit_cache.append(ji)
+  return graphed_jit_cache
+
+# *** JIT ***
 
 ReturnType = TypeVar('ReturnType')
 class TinyJit(Generic[ReturnType]):
@@ -87,44 +125,7 @@ class TinyJit(Generic[ReturnType]):
       if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
 
       # Condense the items into a graph executor.
-      if getenv("JIT") != 2:
-        # Split JIT cache into batches for faster graph execution.
-        # This allows the accelerator to run some batches while subsequent graphs are still being updated.
-        graphed_jit_cache: List[JitItem] = []
-        current_batch: List[JitItem] = []
-        current_device: Union[Compiled, None] = None
-
-        # Flush the current batch.
-        def flush():
-          nonlocal current_batch, current_device
-          assert current_device is not None
-          try:
-            graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
-            if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
-          except GraphException as e:
-            graphed_jit_cache.extend(current_batch)
-            if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
-          current_batch = []
-          current_device = None
-
-        for i,ji in enumerate(self.jit_cache):
-          # If the jit item can potentially be graphed, put it in a batch.
-          can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
-          if can_be_graphed:
-            assert isinstance(ji.prg, CompiledASTRunner)
-            # If the device changed we flush the batch early and append this item for the next batch.
-            if current_device is not None and ji.prg.device != current_device: flush()
-            current_device = ji.prg.device
-            current_batch.append(ji)
-
-          # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
-          # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
-          if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush() # noqa: E501
-
-          # If the jit item cannot be graphed, put it right into the final cache after the flush.
-          if not can_be_graphed: graphed_jit_cache.append(ji)
-
-        self.jit_cache = graphed_jit_cache
+      if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
 
       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
     elif self.cnt == 0:
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index 1be2526ae5..fc981df9da 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -1,11 +1,11 @@
 import ctypes
 from typing import Any, Optional, Tuple, Dict, List, cast
 import gpuctypes.cuda as cuda
-from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same
+from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
 from tinygrad.device import CompiledASTRunner, update_stats, Buffer
 from tinygrad.runtime.ops_cuda import check, cu_time_execution
 from tinygrad.shape.symbolic import Variable
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException # noqa: E501
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 
 class CUDAGraph:
   def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index cd0154221f..686ea8ea49 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -1,9 +1,9 @@
 from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2
+from tinygrad.helpers import dedup, unwrap2, GraphException
 from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import MetalDevice
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index db97c85faa..b7dc27fbc3 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -112,7 +112,7 @@ class Tensor:
 
   @staticmethod
   def corealize(lst:Iterable[Tensor]):
-    return run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
+    run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
 
   def realize(self) -> Tensor:
     run_schedule(self.lazydata.schedule())
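
For reviewers who want to poke at the new batching path, here is a minimal, hypothetical usage sketch (not part of this diff). It assumes a backend whose device exposes a graph implementation (e.g. CUDA or Metal); the jitted function, tensor shapes, and call count are made up for illustration. Run with DEBUG=2 to see the "JIT GRAPHing batch ..." messages printed by apply_graph_to_jit, and adjust the per-batch size with the JIT_BATCH_SIZE environment variable (default 64). Batches that raise GraphException during flush simply fall back to the un-graphed kernels, as the except branch above shows.

# Hypothetical sketch only: the `step` function, shapes, and loop count are illustrative.
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  # The kernels captured here are what apply_graph_to_jit later groups into
  # per-device graph batches of up to getenv("JIT_BATCH_SIZE", 64) items.
  return (x.relu().sum(axis=1) + 1).realize()

for _ in range(3):  # the first call runs eagerly; capture and graphing happen on a later call
  out = step(Tensor.rand(4, 4).realize())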