diff --git a/test/amd/test_sqtt_profiler.py b/test/amd/test_sqtt_profiler.py index 2cbf884b2b..62a15a94b5 100644 --- a/test/amd/test_sqtt_profiler.py +++ b/test/amd/test_sqtt_profiler.py @@ -71,7 +71,7 @@ class TestSQTTProfiler(unittest.TestCase): for i,s in enumerate(sqtt[1:], start=1): self.assertEqual(s["name"], f"{kernel_name} n{i+1}") # TODO: can we trace SQTT for graphed kernels? - def test_jit_graph(self, kernel_count=3*2): + def test_jit_graph(self, kernel_count=3*1): @TinyJit def f(a): return ((a + 1).contiguous() + 2).contiguous().sum() t = Tensor.empty(32) diff --git a/test/backend/test_graph.py b/test/backend/test_graph.py index 94448714cc..3c6485be5d 100644 --- a/test/backend/test_graph.py +++ b/test/backend/test_graph.py @@ -82,7 +82,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT): ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)] # Build graphs - gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs] + gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(None, None, graph)) for graph in graphs] for _ in range(runs): test_bufs = helper_run_jit(gr_ji, bufs, out_buffers) diff --git a/test/backend/test_jit.py b/test/backend/test_jit.py index a768b641fb..0ce15803a3 100644 --- a/test/backend/test_jit.py +++ b/test/backend/test_jit.py @@ -577,7 +577,7 @@ class TestJitPrune(unittest.TestCase): a = Tensor.rand(16).realize() out = w2_noprune(a) np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())]) - assert len(w2_noprune.captured.jit_cache) == 2 + assert_jit_cache_len(w2_noprune, 2) for _ in range(3): a = Tensor.rand(16).realize() diff --git a/test/backend/test_profiler.py b/test/backend/test_profiler.py index 34525a46b3..c42aa443c4 100644 --- a/test/backend/test_profiler.py +++ b/test/backend/test_profiler.py @@ -134,7 +134,7 @@ class TestProfiler(unittest.TestCase): _, _ = helper_profile_filter_device(profile, TestProfiler.d0.device) _, _ = helper_profile_filter_device(profile, d1.device) - assert len(graph_evs) == 1, "one graph event is expected" + assert len(graph_evs) == 2, "2 graph events are expected" assert len(graph_evs[0].ents) == 2, "two entities are expected" @unittest.skipIf(CI or not issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "skip CI") diff --git a/test/null/test_attention.py b/test/null/test_attention.py index 9e6f933157..39f93c5bce 100644 --- a/test/null/test_attention.py +++ b/test/null/test_attention.py @@ -1,6 +1,7 @@ import unittest from tinygrad import Tensor, dtypes, TinyJit, UOp from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis +from test.helpers import assert_jit_cache_len def apply_rope(x:Tensor, start_pos:int): B, H, T, Hd = x.shape @@ -28,12 +29,8 @@ class TestAttention(unittest.TestCase): for _ in range(3): rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1)) rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1)) - noprune_size = len(rope_noprune.captured.jit_cache) - prune_size = len(rope_prune.captured.jit_cache) - - self.assertGreater(noprune_size, prune_size) - self.assertGreaterEqual(noprune_size, 2) - self.assertEqual(prune_size, 1) + assert_jit_cache_len(rope_prune, 1) + assert_jit_cache_len(rope_noprune, 3) if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_hcq_graph.py b/test/unit/test_hcq_graph.py index 5fcb20578c..9f7fb8d472 100644 --- a/test/unit/test_hcq_graph.py +++ b/test/unit/test_hcq_graph.py @@ -1,7 +1,8 @@ import unittest from tinygrad import Device, Tensor from tinygrad.engine.jit import TinyJit -from tinygrad.engine.realize import CompiledRunner +from tinygrad.uop.ops import UOp, Ops +from tinygrad.dtype import dtypes from tinygrad.runtime.graph.hcq import HCQGraph from tinygrad.runtime.support.hcq import HCQCompiled from tinygrad.runtime.support.usb import USBMMIOInterface @@ -19,27 +20,23 @@ class TestHCQUnit(unittest.TestCase): inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize() for _ in range(5): f(inp, inp_cpu) - gpu_ei, cpu_ei, gpu_devs = None, None, [] - for ji in f.captured.jit_cache: - if isinstance(ji.prg, CompiledRunner): - if ji.prg.dev._is_cpu(): cpu_ei = ji - else: - gpu_ei = ji - if ji.prg.dev not in gpu_devs: gpu_devs.append(ji.prg.dev) - assert gpu_ei is not None and cpu_ei is not None and len(gpu_devs) > 0 + # construct minimal CALL UOps for supports_exec_item + gpu_call = UOp(Ops.SINK).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float)) + cpu_call = UOp(Ops.SINK).call(UOp.new_buffer("CPU", 1, dtypes.float)) + gpu_devs = [d0] # local MMIO: GPU works alone and with CPU in batch (cpu_support=True) - assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True - assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is True - assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is True + assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True + assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is True + assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is True # USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False) orig_view = d0.timeline_signal.base_buf.view try: d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B') - assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True - assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is False - assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is False + assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True + assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is False + assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is False finally: d0.timeline_signal.base_buf.view = orig_view diff --git a/test/unit/test_metal_graph.py b/test/unit/test_metal_graph.py index affd20fec4..74c733ba90 100644 --- a/test/unit/test_metal_graph.py +++ b/test/unit/test_metal_graph.py @@ -1,34 +1,44 @@ import unittest from unittest.mock import MagicMock from tinygrad import Device -from tinygrad.engine.realize import CompiledRunner +from tinygrad.uop.ops import Ops +from tinygrad.dtype import dtypes @unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run") class TestMetalGraph(unittest.TestCase): def setUp(self): from tinygrad.runtime.graph.metal import MetalGraph - from tinygrad.runtime.ops_metal import MetalBuffer self.MetalGraph = MetalGraph - self.MetalBuffer = MetalBuffer self.dev = Device[Device.DEFAULT] - def metal_buf(self, offset): return MagicMock(_buf=self.MetalBuffer(MagicMock(), 4, offset)) + def metal_buf(self, offset): + buf = MagicMock() + if offset > 0: + buf.op = Ops.BUFFER_VIEW + buf.arg = (None, offset) + buf.dtype = dtypes.uint8 + else: + buf.op = Ops.BUFFER + buf.device = Device.DEFAULT + return buf - def ei(self, *bufs): - ei = MagicMock() - ei.prg = MagicMock(spec=CompiledRunner) - ei.bufs = list(bufs) - return ei + def call(self, *bufs): + c = MagicMock() + c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs) + return c def test_supports_exec_item_normal_offset(self): - assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True + assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True def test_supports_exec_item_overflow_offset(self): - assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(0x100000000))) is False + assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False def test_supports_exec_item_nonmetal_buf(self): - # HCQBuffer.offset is a method, not an int — must not crash - self.MetalGraph.supports_exec_item([self.dev], self.ei(MagicMock(**{"_buf.offset": lambda: 0}))) + # non-BUFFER_VIEW ops should not be checked for offset + buf = MagicMock() + buf.op = Ops.BUFFER + buf.device = Device.DEFAULT + self.MetalGraph.supports_exec_item([self.dev], self.call(buf)) if __name__ == "__main__": unittest.main() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index cc724213be..d302392433 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -1,11 +1,11 @@ from typing import TypeVar, Generic, Callable, cast, Any import functools, collections from tinygrad.tensor import Tensor -from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize +from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ from tinygrad.device import Buffer, Compiled, Device, MultiBuffer -from tinygrad.dtype import DType -from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers, track_rewrites -from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates +from tinygrad.dtype import DType, dtypes +from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite +from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs from tinygrad.engine.schedule import linear_to_schedule from tinygrad.nn.state import get_parameters @@ -22,9 +22,57 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]: else: onetime.append(si) return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime)) -@track_rewrites(lambda linear,held_bufs,ret: f"JIT {pluralize('Kernel', len(ret))}") -def jit_lower(linear:UOp, held_bufs:set[UOp]) -> list[ExecItem]: - return [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(linear, held_bufs))] +def create_graph_call(batch:list[UOp], input_buffers:set[Buffer]) -> UOp: + def bufs_for(b): return b.buffer.bufs if isinstance(b.buffer, MultiBuffer) else [b.buffer] + + input_list = dedup(b for si in batch for b in si.src[1:] if b.op in (Ops.BUFFER, Ops.BUFFER_VIEW) and not input_buffers.isdisjoint(bufs_for(b))) + cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph") + return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata)) + +def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp: + new_src: list[UOp] = [] + current_batch: list[UOp] = [] + current_batch_devs: list[Compiled] = [] + + def flush_batch(): + nonlocal current_batch, current_batch_devs, max_batch_size, new_src + if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): new_src.extend(current_batch) + else: + new_src.append(create_graph_call(current_batch, input_buffers)) + max_batch_size *= 2 + if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels") + current_batch, current_batch_devs = [], [] + + for si in linear.src: + if si.src[0].op is Ops.BUFFER_VIEW: continue + + devs = [Device[x] for x in (si.device if isinstance(si.device, tuple) else (si.device,))] + graph_t = graph_class(devs[0]) if devs[0].graph is not None else None + + can_graph = graph_t is not None and graph_t.supports_exec_item(devs, si) + can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_exec_item(current_batch_devs, si)) \ + and (max_batch_size == 0 or len(current_batch) < max_batch_size) + if not can_extend and current_batch: flush_batch() + + # append this si and update devs + (current_batch if can_graph else new_src).append(si) + current_batch_devs = dedup(current_batch_devs + devs) if can_graph else [] + if current_batch: flush_batch() + return linear.replace(src=tuple(new_src)) + +def jit_cache_bufs(jit_cache:list[ExecItem]): + for ei in jit_cache: + for b in ei.bufs: + if b is not None: yield b + if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache) + +@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}") +def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]: + if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear") + linear = memory_plan_rewrite(linear, held_bufs) + if JIT < 2: linear = graph_split_rewrite(linear, set(input_buffers or []), max_batch_size=JIT_BATCH_SIZE.value) + if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear") + return [ei.lower() for ei in linear_to_schedule(linear)] class GraphException(Exception): pass class JitError(Exception): pass @@ -38,55 +86,6 @@ def _check_no_non_tensor_return(ret): def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph -def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], - orig_valid_positions: dict[int, set[int]]|None = None, max_batch_size=0) -> list[ExecItem]: - # Split JIT cache into batches for faster graph execution. - # This allows the accelerator to run some batches while subsequent graphs are still being updated. - graphed_jit_cache: list[ExecItem] = [] - current_batch: list[ExecItem] = [] - current_batch_devs: list[Compiled] = [] - - def flush_batch(): - nonlocal current_batch, current_batch_devs, max_batch_size - try: - if len(current_batch_devs) == 0: raise GraphException("no device for graph") - if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph") - graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals, orig_valid_positions=orig_valid_positions) - # clear jit inputs to allow their memory to be freed/reused - for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None - graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner)) - max_batch_size *= 2 - if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}") - except GraphException as e: - graphed_jit_cache.extend(current_batch) - if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}") - current_batch = [] - current_batch_devs = [] - - for ji in jit_cache: - match ji.prg: - case CompiledRunner(): ji_graph_dev = ji.prg.dev - case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device] - case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None) - case ViewOp(): continue # ViewOps are just ignored - case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed - - # Check if this jit item can be graphed at all, so check if a new graph supports the current item. - can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji) - - # Check if the current batch can be extended with this item. - can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \ - graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji) - can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size) - - # Flush the current batch if any, since it can't be extended or is full. - if not can_extend_graph_batch and len(current_batch) > 0: flush_batch() - (current_batch if can_be_graphed else graphed_jit_cache).append(ji) - current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else [] - - if len(current_batch) > 0: flush_batch() - return graphed_jit_cache - def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer], orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]: input_replace: dict[tuple[int, int], int] = {} @@ -99,23 +98,30 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer], return input_replace class GraphRunner(Runner): - def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int], - orig_valid_positions: dict[int, set[int]]|None = None): - self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph - self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers, orig_valid_positions) + def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None, + jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None): + # TODO: captured jit as linear? + if linear is not None: + jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])] + for b in jit_cache_bufs(jit_cache): b.ensure_allocated() + input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {} + self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {} + self.var_vals_replace:dict[int, list[tuple[int, int]]] = {} self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {} self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {} def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim) - self.vars = sorted(var_vals.keys()) - self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] + - [tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)]) + crs = [(ji, ji.prg) for ji in self.jit_cache if isinstance(ji.prg, CompiledRunner)] + self.vars = sorted({v.expr for ji,p in crs for v in p.p.vars if v.expr not in ji.fixedvars | p.p.runtimevars}) + self.symbolic_dims = dedup([tuple(d) for _,p in crs if (d:=p.p.local_size) and is_sym_dim(d)] + + [tuple(d) for _,p in crs if (d:=p.p.global_size) and is_sym_dim(d)]) + def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None estimates = Estimates() - for j,ji in enumerate(jit_cache): + for j,ji in enumerate(self.jit_cache): assert ji.prg is not None estimates += ji.prg.estimates if isinstance(ji.prg, CompiledRunner): @@ -132,8 +138,10 @@ class GraphRunner(Runner): self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list) self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list) - assert jit_cache[0].prg is not None - super().__init__(colored(f"", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify()) + assert self.jit_cache[0].prg is not None + super().__init__(colored(f"", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify()) + + def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace) def updated_vars(self, var_vals: dict[str, int]): vals = [var_vals[v] for v in self.vars] @@ -165,18 +173,25 @@ class GraphRunner(Runner): return list({id(x):x for x in wait_nodes}.values()) @staticmethod - def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1 + def _all_devs(batch_devs:list[Compiled], new_call:UOp) -> list[Compiled]: + return dedup(batch_devs + [Device[x] for b in new_call.src[1:] if b.op is not Ops.BIND + for x in (b.device if isinstance(b.device, tuple) else (b.device,))]) + + @staticmethod + def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: + return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1 # a marker for your graph supporting multiple devices of the same type class MultiGraphRunner(GraphRunner): @staticmethod - def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: + def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: # Devices must be the same type - return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1 + return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1 def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]: if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins] if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])] + if isinstance(ei.prg, GraphRunner): return dedup([b for inner in ei.prg.jit_cache for b in get_out_buffers_for_ei(inner)]) return [] def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]): @@ -201,6 +216,7 @@ class CapturedJit(Generic[ReturnType]): self._jit_cache: list[ExecItem] = self.jit_cache self._input_replace: dict[tuple[int, int], int] = self.input_replace self._first_run = True + self._needs_rebuild = False # precompute read-after-write hazard detection self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)} self._input_to_max_reader: dict[int, int] = {} @@ -217,12 +233,13 @@ class CapturedJit(Generic[ReturnType]): depends: set[Buffer|None] = set([None]) update_depends(depends, self.jit_cache) arenas = {b._base for b in depends if b is not None and b._base is not None} - to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas} + to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas} for b in to_free: if hasattr(b, '_buf'): b.deallocate() for a in arenas: if a.allocated_views == 0 and a.is_allocated(): a.deallocate() self.__post_init__() + self._needs_rebuild = True # jit exec def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType: @@ -237,23 +254,13 @@ class CapturedJit(Generic[ReturnType]): for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx] - # Condense the items into a graph executor. + # allocate intermediates if freed on first run if self._first_run: - # allocate intermediates if freed - for ji in self.jit_cache: - for b in ji.bufs: - if b is not None: b.ensure_allocated() - # create graph if needed - if JIT < 2: - # build a map from ExecItem object to the buffer positions that are valid inputs (from original input_replace) - orig_valid_positions: dict[int, set[int]] = {} # id(ExecItem) -> set of valid buffer indices - for (j, i) in self.input_replace: orig_valid_positions.setdefault(id(self.jit_cache[j]), set()).add(i) - self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, orig_valid_positions, max_batch_size=JIT_BATCH_SIZE.value) - # recompute input_replace: GraphRunner items have all positions valid, non-GraphRunner items use orig_valid_positions - valid_positions = {id(ji): set(range(len(ji.bufs))) if isinstance(ji.prg, GraphRunner) else orig_valid_positions.get(id(ji), set()) - for ji in self._jit_cache} - self._input_replace = get_input_replace(self._jit_cache, input_buffers, valid_positions) - self._first_run = False + for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated() + if self._needs_rebuild: + for ei in self.jit_cache: + if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace) + self._first_run = self._needs_rebuild = False if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels") for ei in self._jit_cache: ei.run(var_vals, jit=True) @@ -341,7 +348,7 @@ class TinyJit(Generic[ReturnType]): held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER} with Context(BEAM=getenv("JITBEAM", BEAM.value)): - jit_cache = jit_lower(big_linear, held_bufs) + jit_cache = jit_lower(big_linear, held_bufs, input_buffers) # track inputs that are views of buffers # TODO: eventually expected_buffers should live in ExecItem diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index c7675d19f7..7e1dd553ee 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -131,6 +131,7 @@ si_lowerer = PatternMatcher([ if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \ else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))), (UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)), + (UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"), lambda ctx,cf: Device[cf.device if isinstance(cf.device,str) else cf.device[0]].graph(cf, ctx)) ]) @dataclass diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 54043ed334..4b553c7a77 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -74,7 +74,9 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]: buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize) ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND] metadata = si.arg.metadata - if any(isinstance(x, MultiBuffer) for x in ubufs): + if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph": + schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata)) + elif any(isinstance(x, MultiBuffer) for x in ubufs): assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" dnums = [x for x in ast.variables() if x.expr == '_device_num'] for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index ae91f3ce30..e2a9ab320f 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -1,12 +1,12 @@ import collections, time from typing import Any, cast -from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing, TracingKey +from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, suppress_finalizing, TracingKey from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent from tinygrad.dtype import dtypes -from tinygrad.uop.ops import UOp, Variable -from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy -from tinygrad.engine.jit import MultiGraphRunner +from tinygrad.uop.ops import UOp, Ops, Variable +from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy +from tinygrad.engine.jit import GraphRunner, MultiGraphRunner class HCQGraph(MultiGraphRunner): def __init__(self, *args, **kwargs): @@ -239,9 +239,9 @@ class HCQGraph(MultiGraphRunner): for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True)) @staticmethod - def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: + def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool: # Check if all devices are HCQ - all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b])) + all_devs = cast(list[HCQCompiled], GraphRunner._all_devs(batch_devs, new_call)) if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False # If all of devices are mapped into CPU address space, can use CPU inside the peer group. @@ -250,6 +250,8 @@ class HCQGraph(MultiGraphRunner): # Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group. if len(set(d.peer_group for d in all_devs if not (cpu_support and d._is_cpu()))) > 1: return False - # MOCKGPU is not supported, since it can't execute commands in parallel - copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU") - return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy + if new_call.src[0].op is Ops.COPY: + # MOCKGPU is not supported, since it can't execute commands in parallel + is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer + return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU")) + return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index 815338b4ad..2994a2b390 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -3,9 +3,10 @@ import ctypes, re, decimal from tinygrad.dtype import dtypes from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent -from tinygrad.engine.realize import ExecItem, CompiledRunner +from tinygrad.uop.ops import UOp, Ops +from tinygrad.engine.realize import CompiledRunner from tinygrad.engine.jit import GraphRunner, GraphException -from tinygrad.runtime.ops_metal import wait_check, to_ns_str, MetalBuffer +from tinygrad.runtime.ops_metal import wait_check, to_ns_str from tinygrad.runtime.autogen import metal from tinygrad.runtime.support import objc @@ -109,7 +110,7 @@ class MetalGraph(GraphRunner): self.collect_timestamps() @staticmethod - def supports_exec_item(devs, ei:ExecItem) -> bool: + def supports_exec_item(batch_devs, new_call:UOp) -> bool: # Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range. - if any(b is not None and isinstance(b._buf, MetalBuffer) and b._buf.offset > 0xFFFFFFFF for b in ei.bufs): return False - return GraphRunner.supports_exec_item(devs, ei) + if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False + return GraphRunner.supports_exec_item(batch_devs, new_call)