JIT cleanups (#3164)

* move GraphException

* factor out apply_graph_to_jit

* that return was wrong
George Hotz
2024-01-17 23:39:57 -08:00
committed by GitHub
parent f0c178b7e9
commit 67bc2ddfd8
5 changed files with 48 additions and 45 deletions
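For code outside tinygrad, the GraphException move only changes the import path. A minimal sketch of the new-style import and a guarded use; try_graph is a hypothetical helper, not part of the diff, and simply mirrors how flush() below guards graph construction:

from tinygrad.helpers import GraphException   # was: from tinygrad.jit import GraphException

def try_graph(device, jit_cache, input_rawbuffers, var_vals):
  # hypothetical helper: build a device graph for a batch, or signal the caller to fall back
  try:
    return device.graph(jit_cache, input_rawbuffers, var_vals)
  except GraphException:
    return None   # caller falls back to running the kernels one by one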


@@ -65,6 +65,8 @@ def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+str
def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix()
+class GraphException(Exception): pass
class Context(contextlib.ContextDecorator):
  stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs


@@ -3,7 +3,7 @@ from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar,
import functools, itertools, operator
from tinygrad.nn.state import get_parameters
from tinygrad.dtype import DType
-from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten
+from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
@@ -32,7 +32,45 @@ def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
  return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
-class GraphException(Exception): pass
+def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
+  # Split JIT cache into batches for faster graph execution.
+  # This allows the accelerator to run some batches while subsequent graphs are still being updated.
+  graphed_jit_cache: List[JitItem] = []
+  current_batch: List[JitItem] = []
+  current_device: Union[Compiled, None] = None
+  # Flush the current batch.
+  def flush():
+    nonlocal current_batch, current_device
+    assert current_device is not None
+    try:
+      graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
+      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+    except GraphException as e:
+      graphed_jit_cache.extend(current_batch)
+      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+    current_batch = []
+    current_device = None
+  for i,ji in enumerate(jit_cache):
+    # If the jit item can potentially be graphed, put it in a batch.
+    can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
+    if can_be_graphed:
+      assert isinstance(ji.prg, CompiledASTRunner)
+      # If the device changed we flush the batch early and append this item for the next batch.
+      if current_device is not None and ji.prg.device != current_device: flush()
+      current_device = ji.prg.device
+      current_batch.append(ji)
+    # The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
+    # (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
+    if len(current_batch) > 0 and (i==len(jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush()
+    # If the jit item cannot be graphed, put it right into the final cache after the flush.
+    if not can_be_graphed: graphed_jit_cache.append(ji)
+  return graphed_jit_cache
# *** JIT ***
ReturnType = TypeVar('ReturnType')
class TinyJit(Generic[ReturnType]):
@@ -87,44 +125,7 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# Condense the items into a graph executor.
if getenv("JIT") != 2:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache: List[JitItem] = []
current_batch: List[JitItem] = []
current_device: Union[Compiled, None] = None
# Flush the current batch.
def flush():
nonlocal current_batch, current_device
assert current_device is not None
try:
graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
except GraphException as e:
graphed_jit_cache.extend(current_batch)
if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
current_batch = []
current_device = None
for i,ji in enumerate(self.jit_cache):
# If the jit item can potentially be graphed, put it in a batch.
can_be_graphed = isinstance(ji.prg, CompiledASTRunner) and ji.prg.device.graph
if can_be_graphed:
assert isinstance(ji.prg, CompiledASTRunner)
# If the device changed we flush the batch early and append this item for the next batch.
if current_device is not None and ji.prg.device != current_device: flush()
current_device = ji.prg.device
current_batch.append(ji)
# The flush is done when (1) ji is the last one, (2) the size of batch exceeds the maximum batch size or
# (3) the current jit item cannot be graphed, so the current batch is flushed before such a jit item is added.
if len(current_batch) > 0 and (i==len(self.jit_cache)-1 or len(current_batch) >= getenv("JIT_BATCH_SIZE", 64) or not can_be_graphed): flush() # noqa: E501
# If the jit item cannot be graphed, put it right into the final cache after the flush.
if not can_be_graphed: graphed_jit_cache.append(ji)
self.jit_cache = graphed_jit_cache
if getenv("JIT") != 2: self.jit_cache = apply_graph_to_jit(self.jit_cache, input_rawbuffers, var_vals)
self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
elif self.cnt == 0:
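The factored-out apply_graph_to_jit keeps the same batching policy as the inline code it replaces: consecutive graphable kernels on one device are grouped, and a batch is flushed when the device changes, when it reaches JIT_BATCH_SIZE (default 64), or when an ungraphable item arrives; if graph construction raises GraphException, the batch falls back to the plain kernel list. A minimal, self-contained sketch of that grouping policy, with strings standing in for devices and ints for kernels (all names here are illustrative, not tinygrad APIs):

from typing import List, Optional, Tuple

def batch_for_graph(items: List[Tuple[int, Optional[str]]], batch_size: int = 3) -> List[tuple]:
  # items are (kernel_id, device); device=None marks an item that cannot be graphed
  out: List[tuple] = []
  batch: List[int] = []
  dev: Optional[str] = None
  def flush():
    nonlocal batch, dev
    if batch: out.append(("graph", dev, tuple(batch)))   # stand-in for building a device graph
    batch, dev = [], None
  for i, (kid, d) in enumerate(items):
    graphable = d is not None
    if graphable:
      if dev is not None and d != dev: flush()           # device change flushes the batch early
      dev = d
      batch.append(kid)
    # flush when this is the last item, the batch is full, or the item is not graphable
    if batch and (i == len(items)-1 or len(batch) >= batch_size or not graphable): flush()
    if not graphable: out.append(("run", kid))           # ungraphable items pass through unchanged
  return out

print(batch_for_graph([(0, "METAL"), (1, "METAL"), (2, "METAL"), (3, "METAL"), (4, None), (5, "CUDA")]))
# [('graph', 'METAL', (0, 1, 2)), ('graph', 'METAL', (3,)), ('run', 4), ('graph', 'CUDA', (5,))]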


@@ -1,11 +1,11 @@
import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda
-from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same
+from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, GraphException
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals, GraphException # noqa: E501
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
class CUDAGraph:
  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):


@@ -1,9 +1,9 @@
from typing import List, Any, Dict, cast, Optional
import Metal
from tinygrad.dtype import dtypes
-from tinygrad.helpers import dedup, unwrap2
+from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, GraphException
+from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice


@@ -112,7 +112,7 @@ class Tensor:
  @staticmethod
  def corealize(lst:Iterable[Tensor]):
-    return run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
+    run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
  def realize(self) -> Tensor:
    run_schedule(self.lazydata.schedule())
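The tensor.py hunk is the "that return was wrong" fix: Tensor.corealize builds one schedule for several tensors and runs it for its side effect, so, presumably, forwarding the value of run_schedule was misleading and is dropped. A small usage sketch under that assumption (shapes are arbitrary):

from tinygrad.tensor import Tensor

a, b = Tensor.rand(4), Tensor.rand(4)
Tensor.corealize([a, b])          # realizes both tensors in one schedule; no return value after this commit
print(a.numpy(), b.numpy())       # both are already realized at this point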