JitItem -> ExecItem (#4146)

* JitItem -> ExecItem

* execitem in realize

* cleaner

* JITRunner -> Runner
This commit is contained in:
George Hotz
2024-04-11 08:24:57 -07:00
committed by GitHub
parent e79a11b99c
commit b7e281cf10
10 changed files with 55 additions and 58 deletions

View File

@@ -4,7 +4,7 @@ from typing import Tuple, TypeVar, List, Any, cast, Set
import tinygrad.runtime.autogen.hip as hip
from tinygrad.helpers import DEBUG, getenv, init_c_var
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, JITRunner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
from tinygrad.renderer.cstyle import HIPRenderer
from tinygrad.runtime.driver.hip_comgr import compile_hip
@@ -128,7 +128,7 @@ class HIPAllocator(LRUAllocator):
hip_set_device(self.device.device)
check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
class HIPSyncEvent(JITRunner):
class HIPSyncEvent(Runner):
def __init__(self, lb):
self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
super().__init__()
@@ -138,7 +138,7 @@ class HIPSyncEvent(JITRunner):
check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
class HIPWaitEvent(JITRunner):
class HIPWaitEvent(Runner):
def __init__(self, device):
self.device, self.dname = cast(HIPDevice, Device[device]), device
super().__init__()

View File

@@ -4,7 +4,7 @@ from tinygrad.device import Device, Buffer, BufferXfer
from tinygrad.dtype import dtypes
from tinygrad.runtime.driver.hsa import AQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
from tinygrad.engine.jit import JitItem
from tinygrad.engine.realize import ExecItem
def get_hsa_inc_prog(dev, inc=1):
prg = f"""
@@ -102,7 +102,7 @@ class TestHSADriver(unittest.TestCase):
test_buf1.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(1*4)))
jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
graph = HSAGraph(jit_cache, [], {})
for i in range(10000):

View File

@@ -1,6 +1,6 @@
import sys
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import JITRunner
from tinygrad.device import Runner
from tinygrad.dtype import DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import Context, CI, OSX, getenv
@@ -13,12 +13,12 @@ def derandomize_model(model):
def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache) > 0
# until we have a better way of typing the prg in JitItem
if issubclass(type(fxn.jit_cache[0].prg), JITRunner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
# until we have a better way of typing the prg in ExecItem
if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
assert len(fxn.jit_cache) == expected_len
else:
assert len(fxn.jit_cache) == 1
# until we have a better way of typing the prg in JitItem
# until we have a better way of typing the prg in ExecItem
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

View File

@@ -39,7 +39,7 @@ Device = _Device()
# **************** base Runner + helpers ****************
class JITRunner:
class Runner:
def __init__(self):
self.op_estimate:sint = 0
self.mem_estimate:sint = 0
@@ -67,7 +67,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
# **************** Buffer / Allocator ****************
class BufferCopy(JITRunner):
class BufferCopy(Runner):
def copy(self, dest, src):
if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
@@ -158,7 +158,7 @@ class Compiler:
if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
return lib
class CompiledASTRunner(JITRunner):
class CompiledASTRunner(Runner):
def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
super().__init__()
@@ -201,7 +201,7 @@ class CompiledASTRunner(JITRunner):
self.first_run = False
return et
class MultiDeviceJITGraph(JITRunner):
class MultiDeviceJITGraph(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
raise NotImplementedError("override this")

View File

@@ -4,48 +4,43 @@ import functools, itertools, operator
from tinygrad.nn.state import get_parameters
from tinygrad.dtype import DType
from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.device import Compiled, Runner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
from tinygrad.tensor import Tensor
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.engine.realize import ExecItem
from weakref import ref, WeakKeyDictionary
from dataclasses import dataclass
@dataclass(frozen=True)
class JitItem:
prg: JITRunner # or a graph executor like MetalGraph
rawbufs: List[Optional[Buffer]]
def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[sint, int]:
def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0), \
functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0)
def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
input_replace: Dict[Tuple[int, int], int] = {}
for j,ji in enumerate(jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in input_rawbuffers:
input_replace[(j,i)] = input_rawbuffers.index(a)
return input_replace
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
max_batch_size = getenv("JIT_BATCH_SIZE", 32)
graphed_jit_cache: List[JitItem] = []
current_batch: List[JitItem] = []
graphed_jit_cache: List[ExecItem] = []
current_batch: List[ExecItem] = []
current_device: Optional[Compiled] = None
def flush_batch():
nonlocal current_batch, current_device, max_batch_size
try:
if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
max_batch_size *= 2
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
except GraphException as e:
@@ -82,7 +77,7 @@ class TinyJit(Generic[ReturnType]):
self.reset()
def reset(self):
self.jit_cache: List[JitItem] = []
self.jit_cache: List[ExecItem] = []
self.input_replace: Dict[Tuple[int, int], int] = {}
self.cnt: int = 0
self.ret: Optional[ReturnType] = None
@@ -162,7 +157,7 @@ class PlaceHolder:
class _CacheCollector:
def __init__(self):
self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None
def start(self, var_vals:Optional[Dict[Variable, int]]=None):
self.cache = []
@@ -179,9 +174,9 @@ class _CacheCollector:
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
def finish(self) -> List[JitItem]:
def finish(self) -> List[ExecItem]:
if self.cache is None: return []
buffer_cache: Dict[PlaceHolder, Buffer] = {}
saved_cache, self.cache = self.cache, None
return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
CacheCollector = _CacheCollector()

View File

@@ -1,21 +1,29 @@
from typing import List, Dict, Optional
from typing import List, Dict, Optional, cast, Generator
from dataclasses import dataclass
from tinygrad.helpers import colored
from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
from tinygrad.device import JITRunner, Device, BufferCopy, BufferXfer, update_stats
from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
from tinygrad.buffer import Buffer
from tinygrad.shape.symbolic import Variable
class CustomOp(JITRunner):
@dataclass(frozen=True)
class ExecItem:
prg: Runner
rawbufs: List[Optional[Buffer]]
def run(self, var_vals:Optional[Dict[Variable, int]]=None):
self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})
class CustomOp(Runner):
def __init__(self, fxn):
self.fxn = fxn
super().__init__()
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
class EmptyOp(JITRunner):
class EmptyOp(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
update_stats(colored(f"empty {rawbufs[0].size:10d} {rawbufs[0].dtype}", "yellow"), 0, 0, {}, jit, 1, device=rawbufs[0].device)
def lower_schedule_item(si:ScheduleItem) -> JITRunner:
def lower_schedule_item(si:ScheduleItem) -> Runner:
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
@@ -27,15 +35,8 @@ def lower_schedule_item(si:ScheduleItem) -> JITRunner:
if ast.op is LoadOps.EMPTY: return EmptyOp()
raise RuntimeError(f"don't know how to lower {ast}")
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]] = None):
while len(schedule):
si = schedule.pop(0)
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
# get the program
prg = lower_schedule_item(si)
# allocate output buffers
for out in si.outputs: out.ensure_allocated()
# run the function (put it in JIT)
prg.exec(list(si.outputs+si.inputs), var_vals if var_vals is not None else {})
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
for ei in lower_schedule(schedule): ei.run(var_vals)

View File

@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
class CUDAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
# Check all jit items are compatible.
if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException

View File

@@ -5,8 +5,8 @@ from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
import tinygrad.runtime.autogen.hsa as hsa
from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
self.available_packet_slots -= 1
class HSAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
self.jit_cache = jit_cache
self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore

View File

@@ -3,12 +3,13 @@ import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice, wait_check
class MetalGraph:
def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
self.jit_cache = jit_cache

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import os, mmap, _posixshmem, io, functools
from typing import Dict, List, Any, Optional
from tinygrad.helpers import prod, OSX
from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
from tinygrad.device import Compiled, Allocator, Runner, Buffer
from tinygrad.ops import UnaryOps, LazyOp, BufferOps
from tinygrad.shape.view import strides_for_shape
@@ -32,7 +32,7 @@ class DiskAllocator(Allocator):
else:
dest[:] = src._buf()
class DiskRunner(JITRunner):
class DiskRunner(Runner):
def __init__(self, ast:LazyOp):
# two ASTs are allowed here.
assert ast.op is BufferOps.STORE, "output of AST must be store"