Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
JitItem -> ExecItem (#4146)
* JitItem -> ExecItem
* execitem in realize
* cleaner
* JITRunner -> Runner
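In practice the change is mechanical: every class that subclassed JITRunner now subclasses Runner, and cache entries move from JitItem in tinygrad.engine.jit to the new ExecItem in tinygrad.engine.realize. A minimal before/after sketch, assuming tinygrad at this commit; MyRunner is a hypothetical no-op runner used only for illustration.

# before: from tinygrad.device import JITRunner
#         from tinygrad.engine.jit import JitItem
from tinygrad.device import Runner            # renamed from JITRunner
from tinygrad.engine.realize import ExecItem  # replaces tinygrad.engine.jit.JitItem

class MyRunner(Runner):  # hypothetical runner that does nothing
  def __call__(self, rawbufs, var_vals, wait=False, jit=False): return None

item = ExecItem(MyRunner(), [])  # same (prg, rawbufs) shape the old JitItem had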
@@ -4,7 +4,7 @@ from typing import Tuple, TypeVar, List, Any, cast, Set
 import tinygrad.runtime.autogen.hip as hip
 from tinygrad.helpers import DEBUG, getenv, init_c_var
 from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
-from tinygrad.device import Compiled, LRUAllocator, BufferOptions, JITRunner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
+from tinygrad.device import Compiled, LRUAllocator, BufferOptions, Runner, Device, Buffer, MallocAllocator, update_stats, Compiler, CompilerOptions
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
 
@@ -128,7 +128,7 @@ class HIPAllocator(LRUAllocator):
     hip_set_device(self.device.device)
     check(hip.hipMemcpyAsync(dest, src, sz, hip.hipMemcpyDeviceToDevice, None))
 
-class HIPSyncEvent(JITRunner):
+class HIPSyncEvent(Runner):
   def __init__(self, lb):
     self.lb, self.device, self.dname = lb, cast(HIPDevice, Device[lb.device]), lb.device
     super().__init__()
@@ -138,7 +138,7 @@ class HIPSyncEvent(JITRunner):
     check(hip.hipStreamWriteValue32(None, rawbufs[0]._buf, 1, 0))
     update_stats(colored("sync", "red"), 0, 0, {}, None, 1, jit, device=self.dname)
 
-class HIPWaitEvent(JITRunner):
+class HIPWaitEvent(Runner):
   def __init__(self, device):
     self.device, self.dname = cast(HIPDevice, Device[device]), device
     super().__init__()
test/external/external_test_hsa_driver.py (vendored)
@@ -4,7 +4,7 @@ from tinygrad.device import Device, Buffer, BufferXfer
 from tinygrad.dtype import dtypes
 from tinygrad.runtime.driver.hsa import AQLQueue
 from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
-from tinygrad.engine.jit import JitItem
+from tinygrad.engine.realize import ExecItem
 
 def get_hsa_inc_prog(dev, inc=1):
   prg = f"""
@@ -102,7 +102,7 @@ class TestHSADriver(unittest.TestCase):
     test_buf1.copyin(memoryview(bytearray(1*4)))
     test_buf2.copyin(memoryview(bytearray(1*4)))
 
-    jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
+    jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
     graph = HSAGraph(jit_cache, [], {})
 
     for i in range(10000):
@@ -1,6 +1,6 @@
 import sys
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.device import JITRunner
+from tinygrad.device import Runner
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import Context, CI, OSX, getenv
@@ -13,12 +13,12 @@ def derandomize_model(model):
 
 def assert_jit_cache_len(fxn, expected_len):
   assert len(fxn.jit_cache) > 0
-  # until we have a better way of typing the prg in JitItem
-  if issubclass(type(fxn.jit_cache[0].prg), JITRunner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
+  # until we have a better way of typing the prg in ExecItem
+  if issubclass(type(fxn.jit_cache[0].prg), Runner) and not type(fxn.jit_cache[0].prg).__name__.endswith('Graph'):
     assert len(fxn.jit_cache) == expected_len
   else:
     assert len(fxn.jit_cache) == 1
-    # until we have a better way of typing the prg in JitItem
+    # until we have a better way of typing the prg in ExecItem
     assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
@@ -39,7 +39,7 @@ Device = _Device()
 
 # **************** base Runner + helpers ****************
 
-class JITRunner:
+class Runner:
   def __init__(self):
     self.op_estimate:sint = 0
     self.mem_estimate:sint = 0
@@ -67,7 +67,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
 
 # **************** Buffer / Allocator ****************
 
-class BufferCopy(JITRunner):
+class BufferCopy(Runner):
   def copy(self, dest, src):
     if src.device.startswith("DISK") and hasattr(dest.allocator, 'copy_from_fd') and src.nbytes >= 4096 and hasattr(src.allocator.device, 'fd'):
       dest.allocator.copy_from_fd(dest._buf, src.allocator.device.fd, src._buf.offset, src.nbytes)
@@ -158,7 +158,7 @@ class Compiler:
     if self.cachekey is not None: diskcache_put(self.cachekey, src, lib)
     return lib
 
-class CompiledASTRunner(JITRunner):
+class CompiledASTRunner(Runner):
   def __init__(self, name:str, prg:str, dname:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None,
                variables:Optional[List[Variable]]=None, op_estimate:sint=0, mem_estimate:sint=0, precompiled:Optional[bytes]=None, outcount:int=1):
     super().__init__()
@@ -201,7 +201,7 @@ class CompiledASTRunner(JITRunner):
     self.first_run = False
     return et
 
-class MultiDeviceJITGraph(JITRunner):
+class MultiDeviceJITGraph(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
@@ -4,48 +4,43 @@ import functools, itertools, operator
 from tinygrad.nn.state import get_parameters
 from tinygrad.dtype import DType
 from tinygrad.helpers import DEBUG, merge_dicts, getenv, all_int, Context, GRAPH, flatten, GraphException
-from tinygrad.device import Compiled, JITRunner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
+from tinygrad.device import Compiled, Runner, CompiledASTRunner, Buffer, BufferXfer, MultiDeviceJITGraph, Device
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.engine.realize import ExecItem
 from weakref import ref, WeakKeyDictionary
 from dataclasses import dataclass
 
-@dataclass(frozen=True)
-class JitItem:
-  prg: JITRunner # or a graph executor like MetalGraph
-  rawbufs: List[Optional[Buffer]]
-
-def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[sint, int]:
+def get_jit_stats(jit_cache: List[ExecItem]) -> Tuple[sint, int]:
   return functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0), \
          functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache if isinstance(ji.prg, CompiledASTRunner)], 0)
-def get_input_replace(jit_cache: List[JitItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
+def get_input_replace(jit_cache: List[ExecItem], input_rawbuffers:List[Buffer]) -> Dict[Tuple[int, int], int]:
   input_replace: Dict[Tuple[int, int], int] = {}
   for j,ji in enumerate(jit_cache):
     for i,a in enumerate(ji.rawbufs):
       if a in input_rawbuffers:
         input_replace[(j,i)] = input_rawbuffers.index(a)
   return input_replace
-def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
+def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[ExecItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ((ji.prg.global_size and not all_int(ji.prg.global_size)) or (ji.prg.local_size and not all_int(ji.prg.local_size)))] # noqa: E501
-def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
+def get_jc_idxs_with_updatable_var_vals(jit_cache: List[ExecItem]) -> List[int]:
   return [j for j,ji in enumerate(jit_cache) if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]
 
-def apply_graph_to_jit(jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[JitItem]:
+def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]) -> List[ExecItem]:
   # Split JIT cache into batches for faster graph execution.
   # This allows the accelerator to run some batches while subsequent graphs are still being updated.
   max_batch_size = getenv("JIT_BATCH_SIZE", 32)
-  graphed_jit_cache: List[JitItem] = []
-  current_batch: List[JitItem] = []
+  graphed_jit_cache: List[ExecItem] = []
+  current_batch: List[ExecItem] = []
   current_device: Optional[Compiled] = None
 
   def flush_batch():
     nonlocal current_batch, current_device, max_batch_size
     try:
       if len(current_batch) <= 1 or current_device is None: raise GraphException("only one kernel doesn't graph")
-      graphed_jit_cache.append(JitItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
+      graphed_jit_cache.append(ExecItem(current_device.graph(current_batch, input_rawbuffers, var_vals), cast(List[Optional[Buffer]], input_rawbuffers))) # noqa: E501
       max_batch_size *= 2
       if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
@@ -82,7 +77,7 @@ class TinyJit(Generic[ReturnType]):
     self.reset()
 
   def reset(self):
-    self.jit_cache: List[JitItem] = []
+    self.jit_cache: List[ExecItem] = []
    self.input_replace: Dict[Tuple[int, int], int] = {}
    self.cnt: int = 0
    self.ret: Optional[ReturnType] = None
@@ -162,7 +157,7 @@ class PlaceHolder:
 
 class _CacheCollector:
   def __init__(self):
-    self.cache: Optional[List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]] = None
+    self.cache: Optional[List[Tuple[Runner, List[Union[Buffer, PlaceHolder]]]]] = None
 
   def start(self, var_vals:Optional[Dict[Variable, int]]=None):
     self.cache = []
@@ -179,9 +174,9 @@ class _CacheCollector:
 
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, Buffer) else x for x in rawbufs]))
 
-  def finish(self) -> List[JitItem]:
+  def finish(self) -> List[ExecItem]:
     if self.cache is None: return []
     buffer_cache: Dict[PlaceHolder, Buffer] = {}
     saved_cache, self.cache = self.cache, None
-    return [JitItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
+    return [ExecItem(prg, [x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x for x in pl]) for prg, pl in saved_cache]
 CacheCollector = _CacheCollector()
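After this change, a TinyJit's jit_cache holds ExecItems whose prg fields are Runner subclasses (or graph executors). A hedged usage sketch, assuming tinygrad at this commit and a working default device; the function step and its shapes are made up for illustration:

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor: return (x * 2).realize()

for _ in range(3): step(Tensor.rand(4).realize())  # call a few times so the JIT captures
print(type(step.jit_cache[0]).__name__)            # expected: ExecItem
print(type(step.jit_cache[0].prg).__name__)        # a Runner subclass, e.g. CompiledASTRunner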
@@ -1,21 +1,29 @@
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, cast, Generator
+from dataclasses import dataclass
 from tinygrad.helpers import colored
 from tinygrad.ops import ScheduleItem, BufferOps, LoadOps
-from tinygrad.device import JITRunner, Device, BufferCopy, BufferXfer, update_stats
+from tinygrad.device import Runner, Device, BufferCopy, BufferXfer, update_stats
 from tinygrad.buffer import Buffer
 from tinygrad.shape.symbolic import Variable
 
-class CustomOp(JITRunner):
+@dataclass(frozen=True)
+class ExecItem:
+  prg: Runner
+  rawbufs: List[Optional[Buffer]]
+  def run(self, var_vals:Optional[Dict[Variable, int]]=None):
+    self.prg.exec([cast(Buffer, x).ensure_allocated() for x in self.rawbufs], var_vals if var_vals is not None else {})
+
+class CustomOp(Runner):
   def __init__(self, fxn):
     self.fxn = fxn
     super().__init__()
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
 
-class EmptyOp(JITRunner):
+class EmptyOp(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
     update_stats(colored(f"empty {rawbufs[0].size:10d} {rawbufs[0].dtype}", "yellow"), 0, 0, {}, jit, 1, device=rawbufs[0].device)
 
-def lower_schedule_item(si:ScheduleItem) -> JITRunner:
+def lower_schedule_item(si:ScheduleItem) -> Runner:
   assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
   if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
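ExecItem pairs a Runner with the buffers it runs on; run() ensures every buffer is allocated and then calls the runner's exec. A minimal sketch of driving one by hand, assuming tinygrad at this commit; the device, size, dtype, and the no-op function are illustrative assumptions:

from tinygrad import Device, dtypes
from tinygrad.buffer import Buffer
from tinygrad.engine.realize import ExecItem, CustomOp

buf = Buffer(Device.DEFAULT, 16, dtypes.int32)      # not allocated yet
ei = ExecItem(CustomOp(lambda *bufs: None), [buf])  # CustomOp just calls the given function
ei.run()                                            # allocates buf, then runs prg.exec(...)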
@@ -27,15 +35,8 @@ def lower_schedule_item(si:ScheduleItem) -> JITRunner:
   if ast.op is LoadOps.EMPTY: return EmptyOp()
   raise RuntimeError(f"don't know how to lower {ast}")
 
-def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]] = None):
-  while len(schedule):
-    si = schedule.pop(0)
-
-    # get the program
-    prg = lower_schedule_item(si)
-
-    # allocate output buffers
-    for out in si.outputs: out.ensure_allocated()
-
-    # run the function (put it in JIT)
-    prg.exec(list(si.outputs+si.inputs), var_vals if var_vals is not None else {})
+def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
+  while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
+
+def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
+  for ei in lower_schedule(schedule): ei.run(var_vals)
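run_schedule is now a thin loop over lower_schedule, a generator that turns each ScheduleItem into an ExecItem without running it, so callers can step through or inspect the lowered program before execution. A hedged sketch, assuming tinygrad at this commit; the create_schedule import path is an assumption about this era of the repo:

from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule  # assumed location of create_schedule
from tinygrad.engine.realize import lower_schedule

t = Tensor([1.0, 2.0, 3.0]) + 1
for ei in lower_schedule(create_schedule([t.lazydata])):
  print(type(ei.prg).__name__)  # inspect the lowered Runner
  ei.run()                      # same thing run_schedule would do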
@@ -5,11 +5,11 @@ from tinygrad.helpers import init_c_var, GraphException, getenv
 from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
-  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 
 class CUDAGraph(MultiDeviceJITGraph):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     # Check all jit items are compatible.
     if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
 
@@ -5,8 +5,8 @@ from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, \
-  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
 import tinygrad.runtime.autogen.hsa as hsa
 from tinygrad.runtime.driver.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
 
@@ -26,7 +26,7 @@ class VirtAQLQueue(AQLQueue):
     self.available_packet_slots -= 1
 
 class HSAGraph(MultiDeviceJITGraph):
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
     self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) #type:ignore
@@ -3,12 +3,13 @@ import Metal
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, unwrap2, GraphException
 from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.engine.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
+from tinygrad.engine.realize import ExecItem
+from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_metal import MetalDevice, wait_check
 
 class MetalGraph:
-  def __init__(self, device:MetalDevice, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
+  def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
 
     self.jit_cache = jit_cache
@@ -2,7 +2,7 @@ from __future__ import annotations
 import os, mmap, _posixshmem, io, functools
 from typing import Dict, List, Any, Optional
 from tinygrad.helpers import prod, OSX
-from tinygrad.device import Compiled, Allocator, JITRunner, Buffer
+from tinygrad.device import Compiled, Allocator, Runner, Buffer
 from tinygrad.ops import UnaryOps, LazyOp, BufferOps
 from tinygrad.shape.view import strides_for_shape
 
@@ -32,7 +32,7 @@ class DiskAllocator(Allocator):
     else:
       dest[:] = src._buf()
 
-class DiskRunner(JITRunner):
+class DiskRunner(Runner):
   def __init__(self, ast:LazyOp):
     # two ASTs are allowed here.
     assert ast.op is BufferOps.STORE, "output of AST must be store"