diff --git a/extra/backends/ops_hip.py b/extra/backends/ops_hip.py
index 4c6a834211..2ffc6d58a6 100644
--- a/extra/backends/ops_hip.py
+++ b/extra/backends/ops_hip.py
@@ -4,7 +4,7 @@ from typing import Tuple, TypeVar, List, Any, cast, Set
 import tinygrad.runtime.autogen.hip as hip
 from tinygrad.helpers import DEBUG, getenv, init_c_var
 from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
-from tinygrad.device import Compiled, LRUAllocator, BufferOptions, JITRunner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
+from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
@@ -128,7 +128,7 @@ class HIPAllocator(LRUAllocator):
     hip_set_device(self.device.device)
     check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
 
-class HIPSyncEvent(JITRunner):
+class HIPSyncEvent(Runner):
   def __init__(self, lb):
     self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
     super().__init__()
@@ -138,7 +138,7 @@ class HIPSyncEvent(JITRunner):
     check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
     update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
 
-class HIPWaitEvent(JITRunner):
+class HIPWaitEvent(Runner):
   def __init__(self, device):
     self.device, self.dname = cast(HIPDevice, Device[device]), device
     super().__init__()
diff --git a/test/external/external_test_hsa_driver.py b/test/external/external_test_hsa_driver.py
index 4dd8cded33..827401be61 100644
--- a/test/external/external_test_hsa_driver.py
+++ b/test/external/external_test_hsa_driver.py
@@ -4,7 +4,7 @@ from tinygrad.device import Device, Buffer, BufferXfer
 from tinygrad.dtype import dtypes
 from tinygrad.runtime.driver.hsa import AQLQueue
 from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
-from tinygrad.engine.jit import JitItem
+from tinygrad.engine.realize import ExecItem
 
 def get_hsa_inc_prog(dev, inc=1):
   prg = f"""
@@ -102,7 +102,7 @@ class TestHSADriver(unittest.TestCase):
     test_buf1.copyin(memoryview(bytearray(1*4)))
     test_buf2.copyin(memoryview(bytearray(1*4)))
 
-    jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
+    jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
     graph = HSAGraph(jit_cache, [], {})
 
     for i in range(10000):
diff --git a/test/helpers.py b/test/helpers.py
index 2a25f9996a..1d2bd857ee 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -1,6 +1,6 @@
 import sys
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.device import JITRunner
+from tinygrad.device import Runner
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import Context, CI, OSX, getenv
@@ -13,12 +13,12 @@ def derandomize_model(model):
 
 def assert_jit_cache_len(fxn, expected_len):
   assert len(fxn.jit_cache) > 0
-  # until we have a better way of typing the prg in JitItem
-  if issubclass(type(fxn.jit_cache[0].prg), JITRunner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
+  # until we have a better way of typing the prg in ExecItem
+  if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
     assert len(fxn.jit_cache) == expected_len
   else:
     assert len(fxn.jit_cache) == 1
-    # until we have a better way of typing the prg in JitItem
+    # until we have a better way of typing the prg in ExecItem
     assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
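
Not part of the diff, for orientation: assert_jit_cache_len is used against a TinyJit-wrapped function after the capture pass, and with this change the inspected fxn.jit_cache is a List[ExecItem]. A minimal sketch of that usage, assuming the tinygrad repo root is on sys.path so test.helpers imports; the jitted function and the expected kernel count of 1 are illustrative only.

# hedged sketch: run a jitted function a few times, then check the captured ExecItem count
from tinygrad import Tensor
from tinygrad.engine.jit import TinyJit
from test.helpers import assert_jit_cache_len

@TinyJit
def add(a:Tensor, b:Tensor) -> Tensor: return (a + b).realize()

for _ in range(3): add(Tensor.rand(8).realize(), Tensor.rand(8).realize())  # capture happens on the second call
assert_jit_cache_len(add, 1)  # a single elementwise kernel, so no Graph wrapper is expected
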
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 357b8141b6..41f3a87ce3 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -39,7 +39,7 @@ Device = _Device()
 
 # **************** base Runner + helpers ****************
 
-class JITRunner:
+class Runner:
   def __init__(self):
     self.op_estimate:sint = 0
     self.mem_estimate:sint = 0
@@ -67,7 +67,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
 
 # **************** Buffer / Allocator ****************
 
-class BufferCopy(JITRunner):
+class BufferCopy(Runner):
   def copy(self, dest, src):
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
       dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
@@ -158,7 +158,7 @@ class Compiler:
     if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib
 
-class CompiledASTRunner(JITRunner):
+class CompiledASTRunner(Runner):
   def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
               variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
     super().__init__()
@@ -201,7 +201,7 @@ class CompiledASTRunner(JITRunner):
     self.first_run = False
     return et
 
-class MultiDeviceJITGraph(JITRunner):
+class MultiDeviceJITGraph(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
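
Not part of the diff: the rename does not change the contract a Runner subclass fulfils. Subclasses override __call__(rawbufs, var_vals, wait, jit) and report work through update_stats; the inherited exec() is what ExecItem.run and the JIT CacheCollector go through. A minimal sketch of a subclass; NoopRunner is hypothetical and the update_stats call mirrors the HIPSyncEvent context lines above.

from typing import Dict, List, Optional
from tinygrad.device import Runner, Buffer, update_stats
from tinygrad.shape.symbolic import Variable

class NoopRunner(Runner):  # hypothetical example, not part of tinygrad
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
    # zero ops and zero bytes moved; real runners pass their op_estimate/mem_estimate here
    update_stats("noop", 0, 0, {}, None, 1, jit, device=rawbufs[0].device)
    return None
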
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 9e1a00eec9..6cb4508e6b 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -4,48 +4,43 @@ import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
-from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
+from tinygrad.device import Compiled, Runner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.engine.realize import ExecItem
 from weakref import ref, WeakKeyDictionary
-from dataclasses import dataclass
 
-@dataclass(frozen=True)
-class JitItem:
-  prg: JITRunner # or a graph executor like MetalGraph
-  rawbufs: List[Optional[Buffer]]
-
-def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[sint, int]:
+def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
   return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0), \
          functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0)
-def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
+def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
   input_replace: Dict[Tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
     for i,a in enumerate(ji.rawbufs):
       if a in input_rawbuffers:
         input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace
-def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
+def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
-def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
+def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
 
-def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
+def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
   max_batch_size = getenv("JIT_BATCH_SIZE", 32)
-  graphed_jit_cache: List[JitItem] = []
-  current_batch: List[JitItem] = []
+  graphed_jit_cache: List[ExecItem] = []
+  current_batch: List[ExecItem] = []
   current_device: Optional[Compiled] = None
 
   def flush_batch():
     nonlocal current_batch, current_device, max_batch_size
     try:
       if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
-      graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
+      graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
       max_batch_size *= 2
       if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
@@ -82,7 +77,7 @@ class TinyJit(Generic[ReturnType]):
     self.reset()
 
   def reset(self):
-    self.jit_cache: List[JitItem] = []
+    self.jit_cache: List[ExecItem] = []
     self.input_replace: Dict[Tuple[int, int], int] = {}
     self.cnt: int = 0
     self.ret: Optional[ReturnType] = None
@@ -162,7 +157,7 @@ class PlaceHolder:
 
 class _CacheCollector:
   def __init__(self):
-    self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
+    self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None
 
   def start(self, var_vals:Optional[Dict[Variable, int]]=None):
     self.cache = []
@@ -179,9 +174,9 @@ class _CacheCollector:
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
 
-  def finish(self) -> List[JitItem]:
+  def finish(self) -> List[ExecItem]:
     if self.cache is None: return []
     buffer_cache: Dict[PlaceHolder, Buffer] = {}
     saved_cache, self.cache = self.cache, None
-    return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
+    return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
 
 CacheCollector = _CacheCollector()
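
Not part of the diff: flush_batch above is where a device graph executor gets built, and the result is itself wrapped in an ExecItem, so replay treats a whole graphed batch like any other runner. A hedged sketch of that shape; graph_batch is a hypothetical helper and assumes the device exposes a non-None Compiled.graph, exactly as flush_batch does.

from typing import Dict, List, Optional, cast
from tinygrad.device import Compiled, Buffer
from tinygrad.engine.realize import ExecItem
from tinygrad.shape.symbolic import Variable

def graph_batch(dev:Compiled, batch:List[ExecItem], input_rawbuffers:List[Buffer], var_vals:Dict[Variable, int]) -> ExecItem:
  # dev.graph(...) returns a graph executor (CUDAGraph/HSAGraph/MetalGraph); it becomes the prg of a single ExecItem
  assert dev.graph is not None, "device has no graph support"
  return ExecItem(dev.graph(batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))
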
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 15dcedcaf4..0d3ea40ef1 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -1,21 +1,29 @@
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, cast, Generator
+from dataclasses import dataclass
 from tinygrad.helpers import colored
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
-from tinygrad.device import JITRunner, Device, BufferCopy, BufferXfer, update_stats
+from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
 from tinygrad.buffer import Buffer
 from tinygrad.shape.symbolic import Variable
 
-class CustomOp(JITRunner):
+@dataclass(frozen=True)
+class ExecItem:
+  prg: Runner
+  rawbufs: List[Optional[Buffer]]
+  def run(self, var_vals:Optional[Dict[Variable, int]]=None):
+    self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})
+
+class CustomOp(Runner):
   def __init__(self, fxn):
     self.fxn = fxn
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
 
-class EmptyOp(JITRunner):
+class EmptyOp(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
     update_stats(colored(f"empty {rawbufs[0].size:10d} {rawbufs[0].dtype}", "yellow"), 0, 0, {}, jit, 1, device=rawbufs[0].device)
 
-def lower_schedule_item(si:ScheduleItem) -> JITRunner:
+def lower_schedule_item(si:ScheduleItem) -> Runner:
   assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
   if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
@@ -27,15 +35,8 @@ def lower_schedule_item(si:ScheduleItem) -> JITRunner:
   if ast.op is LoadOps.EMPTY: return EmptyOp()
   raise RuntimeError(f"don't know how to lower {ast}")
 
-def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]] = None):
-  while len(schedule):
-    si = schedule.pop(0)
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+  while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
 
-    # get the program
-    prg = lower_schedule_item(si)
-
-    # allocate output buffers
-    for out in si.outputs: out.ensure_allocated()
-
-    # run the function (put it in JIT)
-    prg.exec(list(si.outputs+si.inputs), var_vals if var_vals is not None else {})
\ No newline at end of file
+def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
+  for ei in lower_schedule(schedule): ei.run(var_vals)
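
Not part of the diff: the rewrite splits the old run_schedule loop into lower_schedule (ScheduleItem -> ExecItem) and ExecItem.run. A minimal sketch of how the pieces compose; it assumes a List[ScheduleItem] has already been produced by the scheduler, and the logging is purely illustrative.

from typing import List
from tinygrad.ops import ScheduleItem
from tinygrad.engine.realize import lower_schedule

def run_with_logging(schedule:List[ScheduleItem]):
  # lower_schedule consumes the list and yields one ExecItem per ScheduleItem
  for ei in lower_schedule(schedule):
    print(f"{type(ei.prg).__name__:20s} buffers={len(ei.rawbufs)}")
    ei.run()  # ensure_allocated() on each buffer, then prg.exec(...)

# run_schedule(schedule, var_vals) is the plain version of this loop
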
diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py
index a3ca2df9e2..50bb388d0a 100644
--- a/tinygrad/runtime/graph/cuda.py
+++ b/tinygrad/runtime/graph/cuda.py
@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
 from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
-  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 
 class CUDAGraph(MultiDeviceJITGraph):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     # Check all jit items are compatible.
     if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py
index 8ac52f3296..76bd1820ab 100644
--- a/tinygrad/runtime/graph/hsa.py
+++ b/tinygrad/runtime/graph/hsa.py
@@ -5,8 +5,8 @@ from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
-  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 import tinygrad.runtime.autogen.hsa as hsa
 from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
     self.available_packet_slots -= 1
 
 class HSAGraph(MultiDeviceJITGraph):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
     self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py
index 40ffe9d70f..6c5a9243e1 100644
--- a/tinygrad/runtime/graph/metal.py
+++ b/tinygrad/runtime/graph/metal.py
@@ -3,12 +3,13 @@ import Metal
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, unwrap2, GraphException
 from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import MetalDevice, wait_check
 
 class MetalGraph:
-  def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
 
     self.jit_cache = jit_cache
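
Not part of the diff: the three backends above now share one constructor contract: they take the List[ExecItem] captured by the JIT plus the input buffers and symbolic var_vals, and lean on get_input_replace/get_jit_stats from tinygrad.engine.jit. A hedged skeleton of that contract; MyGraph is hypothetical, and its __call__ simply replays items one by one instead of building a real hardware graph.

from typing import Dict, List
from tinygrad.device import Buffer, MultiDeviceJITGraph
from tinygrad.engine.jit import get_input_replace, get_jit_stats
from tinygrad.engine.realize import ExecItem
from tinygrad.shape.symbolic import Variable

class MyGraph(MultiDeviceJITGraph):  # hypothetical backend skeleton, for illustration only
  def __init__(self, jit_cache:List[ExecItem], input_rawbuffers:List[Buffer], var_vals:Dict[Variable, int]):
    self.jit_cache = jit_cache
    self.input_replace = get_input_replace(jit_cache, input_rawbuffers)  # (item idx, arg idx) -> index into rawbufs
    self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
    for (j,i),idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = rawbufs[idx]
    for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=jit)  # naive replay, no actual graph launch
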
diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py
index 141170ca91..35c500d534 100644
--- a/tinygrad/runtime/ops_disk.py
+++ b/tinygrad/runtime/ops_disk.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import os, mmap, _posixshmem, io, functools
 from typing import Dict, List, Any, Optional
 from tinygrad.helpers import prod, OSX
-from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
+from tinygrad.device import Compiled, Allocator, Runner, Buffer
 from tinygrad.ops import UnaryOps, LazyOp, BufferOps
 from tinygrad.shape.view import strides_for_shape
@@ -32,7 +32,7 @@ class DiskAllocator(Allocator):
       else: dest[:] = src._buf()
 
-class DiskRunner(JITRunner):
+class DiskRunner(Runner):
   def __init__(self, ast:LazyOp):
     # two ASTs are allowed here.
     assert ast.op is BufferOps.STORE, "output of AST must be store"