jit: graphing in uops (#15489)

* jit: graphing as rewrite rule

* f

* +metal,cuda

* x

* cl

* x

* x

* simpler

* f

* m

* x

* revert?

* revert2

* back

* back

* t

* x

* m

* x

* c

* x

* l

* x

* comment

* smaller

* rv

* x

* x
This commit is contained in:
nimlgen
2026-03-27 19:09:02 +03:00
committed by GitHub
parent 30ebbe7f17
commit 0d6fc0f571
12 changed files with 157 additions and 140 deletions

View File

@@ -71,7 +71,7 @@ class TestSQTTProfiler(unittest.TestCase):
for i,s in enumerate(sqtt[1:], start=1): self.assertEqual(s["name"], f"{kernel_name} n{i+1}")
# TODO: can we trace SQTT for graphed kernels?
def test_jit_graph(self, kernel_count=3*2):
def test_jit_graph(self, kernel_count=3*1):
@TinyJit
def f(a): return ((a + 1).contiguous() + 2).contiguous().sum()
t = Tensor.empty(32)

View File

@@ -82,7 +82,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
# Build graphs
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs]
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(None, None, graph)) for graph in graphs]
for _ in range(runs):
test_bufs = helper_run_jit(gr_ji, bufs, out_buffers)

View File

@@ -577,7 +577,7 @@ class TestJitPrune(unittest.TestCase):
a = Tensor.rand(16).realize()
out = w2_noprune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
assert len(w2_noprune.captured.jit_cache) == 2
assert_jit_cache_len(w2_noprune, 2)
for _ in range(3):
a = Tensor.rand(16).realize()

View File

@@ -134,7 +134,7 @@ class TestProfiler(unittest.TestCase):
_, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
_, _ = helper_profile_filter_device(profile, d1.device)
assert len(graph_evs) == 1, "one graph event is expected"
assert len(graph_evs) == 2, "2 graph events are expected"
assert len(graph_evs[0].ents) == 2, "two entities are expected"
@unittest.skipIf(CI or not issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "skip CI")

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad import Tensor, dtypes, TinyJit, UOp
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
from test.helpers import assert_jit_cache_len
def apply_rope(x:Tensor, start_pos:int):
B, H, T, Hd = x.shape
@@ -28,12 +29,8 @@ class TestAttention(unittest.TestCase):
for _ in range(3):
rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
noprune_size = len(rope_noprune.captured.jit_cache)
prune_size = len(rope_prune.captured.jit_cache)
self.assertGreater(noprune_size, prune_size)
self.assertGreaterEqual(noprune_size, 2)
self.assertEqual(prune_size, 1)
assert_jit_cache_len(rope_prune, 1)
assert_jit_cache_len(rope_noprune, 3)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,7 +1,8 @@
import unittest
from tinygrad import Device, Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner
from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.runtime.graph.hcq import HCQGraph
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.runtime.support.usb import USBMMIOInterface
@@ -19,27 +20,23 @@ class TestHCQUnit(unittest.TestCase):
inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
for _ in range(5): f(inp, inp_cpu)
gpu_ei, cpu_ei, gpu_devs = None, None, []
for ji in f.captured.jit_cache:
if isinstance(ji.prg, CompiledRunner):
if ji.prg.dev._is_cpu(): cpu_ei = ji
else:
gpu_ei = ji
if ji.prg.dev not in gpu_devs: gpu_devs.append(ji.prg.dev)
assert gpu_ei is not None and cpu_ei is not None and len(gpu_devs) > 0
# construct minimal CALL UOps for supports_exec_item
gpu_call = UOp(Ops.SINK).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
cpu_call = UOp(Ops.SINK).call(UOp.new_buffer("CPU", 1, dtypes.float))
gpu_devs = [d0]
# local MMIO: GPU works alone and with CPU in batch (cpu_support=True)
assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is True
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is True
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is True
# USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False)
orig_view = d0.timeline_signal.base_buf.view
try:
d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B')
assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is False
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is False
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is False
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is False
finally:
d0.timeline_signal.base_buf.view = orig_view

View File

@@ -1,34 +1,44 @@
import unittest
from unittest.mock import MagicMock
from tinygrad import Device
from tinygrad.engine.realize import CompiledRunner
from tinygrad.uop.ops import Ops
from tinygrad.dtype import dtypes
@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run")
class TestMetalGraph(unittest.TestCase):
def setUp(self):
from tinygrad.runtime.graph.metal import MetalGraph
from tinygrad.runtime.ops_metal import MetalBuffer
self.MetalGraph = MetalGraph
self.MetalBuffer = MetalBuffer
self.dev = Device[Device.DEFAULT]
def metal_buf(self, offset): return MagicMock(_buf=self.MetalBuffer(MagicMock(), 4, offset))
def metal_buf(self, offset):
buf = MagicMock()
if offset > 0:
buf.op = Ops.BUFFER_VIEW
buf.arg = (None, offset)
buf.dtype = dtypes.uint8
else:
buf.op = Ops.BUFFER
buf.device = Device.DEFAULT
return buf
def ei(self, *bufs):
ei = MagicMock()
ei.prg = MagicMock(spec=CompiledRunner)
ei.bufs = list(bufs)
return ei
def call(self, *bufs):
c = MagicMock()
c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs)
return c
def test_supports_exec_item_normal_offset(self):
assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
def test_supports_exec_item_overflow_offset(self):
assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(0x100000000))) is False
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
def test_supports_exec_item_nonmetal_buf(self):
# HCQBuffer.offset is a method, not an int — must not crash
self.MetalGraph.supports_exec_item([self.dev], self.ei(MagicMock(**{"_buf.offset": lambda: 0})))
# non-BUFFER_VIEW ops should not be checked for offset
buf = MagicMock()
buf.op = Ops.BUFFER
buf.device = Device.DEFAULT
self.MetalGraph.supports_exec_item([self.dev], self.call(buf))
if __name__ == "__main__":
unittest.main()

View File

@@ -1,11 +1,11 @@
from typing import TypeVar, Generic, Callable, cast, Any
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
from tinygrad.dtype import DType
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers, track_rewrites
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.dtype import DType, dtypes
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs
from tinygrad.engine.schedule import linear_to_schedule
from tinygrad.nn.state import get_parameters
@@ -22,9 +22,57 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
else: onetime.append(si)
return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
@track_rewrites(lambda linear,held_bufs,ret: f"JIT {pluralize('Kernel', len(ret))}")
def jit_lower(linear:UOp, held_bufs:set[UOp]) -> list[ExecItem]:
return [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(linear, held_bufs))]
def create_graph_call(batch:list[UOp], input_buffers:set[Buffer]) -> UOp:
def bufs_for(b): return b.buffer.bufs if isinstance(b.buffer, MultiBuffer) else [b.buffer]
input_list = dedup(b for si in batch for b in si.src[1:] if b.op in (Ops.BUFFER, Ops.BUFFER_VIEW) and not input_buffers.isdisjoint(bufs_for(b)))
cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(UOp(Ops.LINEAR, src=tuple(batch)), *input_list), arg="graph")
return cf.call(*input_list, metadata=tuple(m for si in batch for m in si.arg.metadata))
def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp:
new_src: list[UOp] = []
current_batch: list[UOp] = []
current_batch_devs: list[Compiled] = []
def flush_batch():
nonlocal current_batch, current_batch_devs, max_batch_size, new_src
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): new_src.extend(current_batch)
else:
new_src.append(create_graph_call(current_batch, input_buffers))
max_batch_size *= 2
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels")
current_batch, current_batch_devs = [], []
for si in linear.src:
if si.src[0].op is Ops.BUFFER_VIEW: continue
devs = [Device[x] for x in (si.device if isinstance(si.device, tuple) else (si.device,))]
graph_t = graph_class(devs[0]) if devs[0].graph is not None else None
can_graph = graph_t is not None and graph_t.supports_exec_item(devs, si)
can_extend = can_graph and graph_t is not None and (not current_batch_devs or graph_t.supports_exec_item(current_batch_devs, si)) \
and (max_batch_size == 0 or len(current_batch) < max_batch_size)
if not can_extend and current_batch: flush_batch()
# append this si and update devs
(current_batch if can_graph else new_src).append(si)
current_batch_devs = dedup(current_batch_devs + devs) if can_graph else []
if current_batch: flush_batch()
return linear.replace(src=tuple(new_src))
def jit_cache_bufs(jit_cache:list[ExecItem]):
for ei in jit_cache:
for b in ei.bufs:
if b is not None: yield b
if isinstance(ei.prg, GraphRunner): yield from jit_cache_bufs(ei.prg.jit_cache)
@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}")
def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]:
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
linear = memory_plan_rewrite(linear, held_bufs)
if JIT < 2: linear = graph_split_rewrite(linear, set(input_buffers or []), max_batch_size=JIT_BATCH_SIZE.value)
if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View graphed linear")
return [ei.lower() for ei in linear_to_schedule(linear)]
class GraphException(Exception): pass
class JitError(Exception): pass
@@ -38,55 +86,6 @@ def _check_no_non_tensor_return(ret):
def graph_class(dev): return dev.graph.func if isinstance(dev.graph, functools.partial) else dev.graph
def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None, max_batch_size=0) -> list[ExecItem]:
# Split JIT cache into batches for faster graph execution.
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
graphed_jit_cache: list[ExecItem] = []
current_batch: list[ExecItem] = []
current_batch_devs: list[Compiled] = []
def flush_batch():
nonlocal current_batch, current_batch_devs, max_batch_size
try:
if len(current_batch_devs) == 0: raise GraphException("no device for graph")
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals, orig_valid_positions=orig_valid_positions)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))
max_batch_size *= 2
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
except GraphException as e:
graphed_jit_cache.extend(current_batch)
if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
current_batch = []
current_batch_devs = []
for ji in jit_cache:
match ji.prg:
case CompiledRunner(): ji_graph_dev = ji.prg.dev
case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None)
case ViewOp(): continue # ViewOps are just ignored
case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed
# Check if this jit item can be graphed at all, so check if a new graph supports the current item.
can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
# Check if the current batch can be extended with this item.
can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
# Flush the current batch if any, since it can't be extended or is full.
if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
(current_batch if can_be_graphed else graphed_jit_cache).append(ji)
current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
if len(current_batch) > 0: flush_batch()
return graphed_jit_cache
def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]:
input_replace: dict[tuple[int, int], int] = {}
@@ -99,23 +98,30 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
return input_replace
class GraphRunner(Runner):
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
orig_valid_positions: dict[int, set[int]]|None = None):
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers, orig_valid_positions)
def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None,
jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None):
# TODO: captured jit as linear?
if linear is not None:
jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {}
self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {}
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
self.vars = sorted(var_vals.keys())
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
crs = [(ji, ji.prg) for ji in self.jit_cache if isinstance(ji.prg, CompiledRunner)]
self.vars = sorted({v.expr for ji,p in crs for v in p.p.vars if v.expr not in ji.fixedvars | p.p.runtimevars})
self.symbolic_dims = dedup([tuple(d) for _,p in crs if (d:=p.p.local_size) and is_sym_dim(d)] +
[tuple(d) for _,p in crs if (d:=p.p.global_size) and is_sym_dim(d)])
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
estimates = Estimates()
for j,ji in enumerate(jit_cache):
for j,ji in enumerate(self.jit_cache):
assert ji.prg is not None
estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner):
@@ -132,8 +138,10 @@ class GraphRunner(Runner):
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
assert jit_cache[0].prg is not None
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
assert self.jit_cache[0].prg is not None
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify())
def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace)
def updated_vars(self, var_vals: dict[str, int]):
vals = [var_vals[v] for v in self.vars]
@@ -165,18 +173,25 @@ class GraphRunner(Runner):
return list({id(x):x for x in wait_nodes}.values())
@staticmethod
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1
def _all_devs(batch_devs:list[Compiled], new_call:UOp) -> list[Compiled]:
return dedup(batch_devs + [Device[x] for b in new_call.src[1:] if b.op is not Ops.BIND
for x in (b.device if isinstance(b.device, tuple) else (b.device,))])
@staticmethod
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
# a marker for your graph supporting multiple devices of the same type
class MultiGraphRunner(GraphRunner):
@staticmethod
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
# Devices must be the same type
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
if isinstance(ei.prg, CompiledRunner): return [cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs if out not in ei.prg.p.ins]
if isinstance(ei.prg, (BufferCopy, BufferXfer, EncDec)): return [cast(Buffer, ei.bufs[0])]
if isinstance(ei.prg, GraphRunner): return dedup([b for inner in ei.prg.jit_cache for b in get_out_buffers_for_ei(inner)])
return []
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
@@ -201,6 +216,7 @@ class CapturedJit(Generic[ReturnType]):
self._jit_cache: list[ExecItem] = self.jit_cache
self._input_replace: dict[tuple[int, int], int] = self.input_replace
self._first_run = True
self._needs_rebuild = False
# precompute read-after-write hazard detection
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
self._input_to_max_reader: dict[int, int] = {}
@@ -217,12 +233,13 @@ class CapturedJit(Generic[ReturnType]):
depends: set[Buffer|None] = set([None])
update_depends(depends, self.jit_cache)
arenas = {b._base for b in depends if b is not None and b._base is not None}
to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas}
to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas}
for b in to_free:
if hasattr(b, '_buf'): b.deallocate()
for a in arenas:
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
self.__post_init__()
self._needs_rebuild = True
# jit exec
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
@@ -237,23 +254,13 @@ class CapturedJit(Generic[ReturnType]):
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
# Condense the items into a graph executor.
# allocate intermediates if freed on first run
if self._first_run:
# allocate intermediates if freed
for ji in self.jit_cache:
for b in ji.bufs:
if b is not None: b.ensure_allocated()
# create graph if needed
if JIT < 2:
# build a map from ExecItem object to the buffer positions that are valid inputs (from original input_replace)
orig_valid_positions: dict[int, set[int]] = {} # id(ExecItem) -> set of valid buffer indices
for (j, i) in self.input_replace: orig_valid_positions.setdefault(id(self.jit_cache[j]), set()).add(i)
self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, orig_valid_positions, max_batch_size=JIT_BATCH_SIZE.value)
# recompute input_replace: GraphRunner items have all positions valid, non-GraphRunner items use orig_valid_positions
valid_positions = {id(ji): set(range(len(ji.bufs))) if isinstance(ji.prg, GraphRunner) else orig_valid_positions.get(id(ji), set())
for ji in self._jit_cache}
self._input_replace = get_input_replace(self._jit_cache, input_buffers, valid_positions)
self._first_run = False
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
if self._needs_rebuild:
for ei in self.jit_cache:
if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace)
self._first_run = self._needs_rebuild = False
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
for ei in self._jit_cache: ei.run(var_vals, jit=True)
@@ -341,7 +348,7 @@ class TinyJit(Generic[ReturnType]):
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
jit_cache = jit_lower(big_linear, held_bufs)
jit_cache = jit_lower(big_linear, held_bufs, input_buffers)
# track inputs that are views of buffers
# TODO: eventually expected_buffers should live in ExecItem

View File

@@ -131,6 +131,7 @@ si_lowerer = PatternMatcher([
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),
(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"), lambda ctx,cf: Device[cf.device if isinstance(cf.device,str) else cf.device[0]].graph(cf, ctx))
])
@dataclass

View File

@@ -74,7 +74,9 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
metadata = si.arg.metadata
if any(isinstance(x, MultiBuffer) for x in ubufs):
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata))
elif any(isinstance(x, MultiBuffer) for x in ubufs):
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in ast.variables() if x.expr == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):

View File

@@ -1,12 +1,12 @@
import collections, time
from typing import Any, cast
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing, TracingKey
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, suppress_finalizing, TracingKey
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp, Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.uop.ops import UOp, Ops, Variable
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
class HCQGraph(MultiGraphRunner):
def __init__(self, *args, **kwargs):
@@ -239,9 +239,9 @@ class HCQGraph(MultiGraphRunner):
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
@staticmethod
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
# Check if all devices are HCQ
all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
all_devs = cast(list[HCQCompiled], GraphRunner._all_devs(batch_devs, new_call))
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
# If all of devices are mapped into CPU address space, can use CPU inside the peer group.
@@ -250,6 +250,8 @@ class HCQGraph(MultiGraphRunner):
# Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
if len(set(d.peer_group for d in all_devs if not (cpu_support and d._is_cpu()))) > 1: return False
# MOCKGPU is not supported, since it can't execute commands in parallel
copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
if new_call.src[0].op is Ops.COPY:
# MOCKGPU is not supported, since it can't execute commands in parallel
is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer
return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU"))
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM)

View File

@@ -3,9 +3,10 @@ import ctypes, re, decimal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.runtime.ops_metal import wait_check, to_ns_str, MetalBuffer
from tinygrad.runtime.ops_metal import wait_check, to_ns_str
from tinygrad.runtime.autogen import metal
from tinygrad.runtime.support import objc
@@ -109,7 +110,7 @@ class MetalGraph(GraphRunner):
self.collect_timestamps()
@staticmethod
def supports_exec_item(devs, ei:ExecItem) -> bool:
def supports_exec_item(batch_devs, new_call:UOp) -> bool:
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
if any(b is not None and isinstance(b._buf, MetalBuffer) and b._buf.offset > 0xFFFFFFFF for b in ei.bufs): return False
return GraphRunner.supports_exec_item(devs, ei)
if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
return GraphRunner.supports_exec_item(batch_devs, new_call)