mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
jit: graphing in uops (#15489)
* jit: graphing as rewrite rule * f * +metal,cuda * x * cl * x * x * simpler * f * m * x * revert? * revert2 * back * back * t * x * m * x * c * x * l * x * comment * smaller * rv * x * x
This commit is contained in:
@@ -71,7 +71,7 @@ class TestSQTTProfiler(unittest.TestCase):
|
||||
for i,s in enumerate(sqtt[1:], start=1): self.assertEqual(s["name"], f"{kernel_name} n{i+1}")
|
||||
|
||||
# TODO: can we trace SQTT for graphed kernels?
|
||||
def test_jit_graph(self, kernel_count=3*2):
|
||||
def test_jit_graph(self, kernel_count=3*1):
|
||||
@TinyJit
|
||||
def f(a): return ((a + 1).contiguous() + 2).contiguous().sum()
|
||||
t = Tensor.empty(32)
|
||||
|
||||
@@ -82,7 +82,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
|
||||
ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
|
||||
|
||||
# Build graphs
|
||||
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs]
|
||||
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(None, None, graph)) for graph in graphs]
|
||||
|
||||
for _ in range(runs):
|
||||
test_bufs = helper_run_jit(gr_ji, bufs, out_buffers)
|
||||
|
||||
@@ -577,7 +577,7 @@ class TestJitPrune(unittest.TestCase):
|
||||
a = Tensor.rand(16).realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
assert len(w2_noprune.captured.jit_cache) == 2
|
||||
assert_jit_cache_len(w2_noprune, 2)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16).realize()
|
||||
|
||||
@@ -134,7 +134,7 @@ class TestProfiler(unittest.TestCase):
|
||||
_, _ = helper_profile_filter_device(profile, TestProfiler.d0.device)
|
||||
_, _ = helper_profile_filter_device(profile, d1.device)
|
||||
|
||||
assert len(graph_evs) == 1, "one graph event is expected"
|
||||
assert len(graph_evs) == 2, "2 graph events are expected"
|
||||
assert len(graph_evs[0].ents) == 2, "two entities are expected"
|
||||
|
||||
@unittest.skipIf(CI or not issubclass(type(Device[Device.DEFAULT]), HCQCompiled), "skip CI")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||
from test.helpers import assert_jit_cache_len
|
||||
|
||||
def apply_rope(x:Tensor, start_pos:int):
|
||||
B, H, T, Hd = x.shape
|
||||
@@ -28,12 +29,8 @@ class TestAttention(unittest.TestCase):
|
||||
for _ in range(3):
|
||||
rope_noprune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
rope_prune(Tensor.randn(1, 2, 4, 8, dtype=dtypes.float32), v_pos.bind(1))
|
||||
noprune_size = len(rope_noprune.captured.jit_cache)
|
||||
prune_size = len(rope_prune.captured.jit_cache)
|
||||
|
||||
self.assertGreater(noprune_size, prune_size)
|
||||
self.assertGreaterEqual(noprune_size, 2)
|
||||
self.assertEqual(prune_size, 1)
|
||||
assert_jit_cache_len(rope_prune, 1)
|
||||
assert_jit_cache_len(rope_noprune, 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import unittest
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.graph.hcq import HCQGraph
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.runtime.support.usb import USBMMIOInterface
|
||||
@@ -19,27 +20,23 @@ class TestHCQUnit(unittest.TestCase):
|
||||
inp, inp_cpu = Tensor.randn(10, 10, device=Device.DEFAULT).realize(), Tensor.randn(10, 10, device="CPU").realize()
|
||||
for _ in range(5): f(inp, inp_cpu)
|
||||
|
||||
gpu_ei, cpu_ei, gpu_devs = None, None, []
|
||||
for ji in f.captured.jit_cache:
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
if ji.prg.dev._is_cpu(): cpu_ei = ji
|
||||
else:
|
||||
gpu_ei = ji
|
||||
if ji.prg.dev not in gpu_devs: gpu_devs.append(ji.prg.dev)
|
||||
assert gpu_ei is not None and cpu_ei is not None and len(gpu_devs) > 0
|
||||
# construct minimal CALL UOps for supports_exec_item
|
||||
gpu_call = UOp(Ops.SINK).call(UOp.new_buffer(Device.DEFAULT, 1, dtypes.float))
|
||||
cpu_call = UOp(Ops.SINK).call(UOp.new_buffer("CPU", 1, dtypes.float))
|
||||
gpu_devs = [d0]
|
||||
|
||||
# local MMIO: GPU works alone and with CPU in batch (cpu_support=True)
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is True
|
||||
|
||||
# USB MMIO: GPU-only still works, but CPU batching must be rejected (cpu_support=False)
|
||||
orig_view = d0.timeline_signal.base_buf.view
|
||||
try:
|
||||
d0.timeline_signal.base_buf.view = USBMMIOInterface(MockUSB(bytearray(256)), 0, 16, fmt='B')
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, gpu_ei) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, cpu_ei) is False
|
||||
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_ei) is False
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, gpu_call) is True
|
||||
assert HCQGraph.supports_exec_item(gpu_devs, cpu_call) is False
|
||||
assert HCQGraph.supports_exec_item(gpu_devs + [cpu_dev], gpu_call) is False
|
||||
finally:
|
||||
d0.timeline_signal.base_buf.view = orig_view
|
||||
|
||||
|
||||
@@ -1,34 +1,44 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from tinygrad import Device
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "METAL", "Metal device required to run")
|
||||
class TestMetalGraph(unittest.TestCase):
|
||||
def setUp(self):
|
||||
from tinygrad.runtime.graph.metal import MetalGraph
|
||||
from tinygrad.runtime.ops_metal import MetalBuffer
|
||||
self.MetalGraph = MetalGraph
|
||||
self.MetalBuffer = MetalBuffer
|
||||
self.dev = Device[Device.DEFAULT]
|
||||
|
||||
def metal_buf(self, offset): return MagicMock(_buf=self.MetalBuffer(MagicMock(), 4, offset))
|
||||
def metal_buf(self, offset):
|
||||
buf = MagicMock()
|
||||
if offset > 0:
|
||||
buf.op = Ops.BUFFER_VIEW
|
||||
buf.arg = (None, offset)
|
||||
buf.dtype = dtypes.uint8
|
||||
else:
|
||||
buf.op = Ops.BUFFER
|
||||
buf.device = Device.DEFAULT
|
||||
return buf
|
||||
|
||||
def ei(self, *bufs):
|
||||
ei = MagicMock()
|
||||
ei.prg = MagicMock(spec=CompiledRunner)
|
||||
ei.bufs = list(bufs)
|
||||
return ei
|
||||
def call(self, *bufs):
|
||||
c = MagicMock()
|
||||
c.src = (MagicMock(op=Ops.PROGRAM),) + tuple(bufs)
|
||||
return c
|
||||
|
||||
def test_supports_exec_item_normal_offset(self):
|
||||
assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
|
||||
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(100), self.metal_buf(0xFFFFFFFF))) is True
|
||||
|
||||
def test_supports_exec_item_overflow_offset(self):
|
||||
assert self.MetalGraph.supports_exec_item([self.dev], self.ei(self.metal_buf(0), self.metal_buf(0x100000000))) is False
|
||||
assert self.MetalGraph.supports_exec_item([self.dev], self.call(self.metal_buf(0), self.metal_buf(0x100000000))) is False
|
||||
|
||||
def test_supports_exec_item_nonmetal_buf(self):
|
||||
# HCQBuffer.offset is a method, not an int — must not crash
|
||||
self.MetalGraph.supports_exec_item([self.dev], self.ei(MagicMock(**{"_buf.offset": lambda: 0})))
|
||||
# non-BUFFER_VIEW ops should not be checked for offset
|
||||
buf = MagicMock()
|
||||
buf.op = Ops.BUFFER
|
||||
buf.device = Device.DEFAULT
|
||||
self.MetalGraph.supports_exec_item([self.dev], self.call(buf))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import TypeVar, Generic, Callable, cast, Any
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, JIT_BATCH_SIZE, dedup, unwrap, pluralize, VIZ
|
||||
from tinygrad.device import Buffer, Compiled, Device, MultiBuffer
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.uop.ops import UOp, Variable, sym_infer, Ops, buffers, track_rewrites
|
||||
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, Variable, sym_infer, Ops, buffers, track_rewrites, graph_rewrite
|
||||
from tinygrad.engine.realize import ExecItem, capturing, BufferCopy, BufferXfer, EncDec, CompiledRunner, Runner, Estimates
|
||||
from tinygrad.engine.memory import memory_plan_rewrite, _collect_bufs
|
||||
from tinygrad.engine.schedule import linear_to_schedule
|
||||
from tinygrad.nn.state import get_parameters
|
||||
@@ -22,9 +22,57 @@ def prune_linear(linear:UOp, needed:set[UOp]) -> tuple[UOp, UOp]:
|
||||
else: onetime.append(si)
|
||||
return linear.replace(src=tuple(kept)), linear.replace(src=tuple(onetime))
|
||||
|
||||
@track_rewrites(lambda linear,held_bufs,ret: f"JIT {pluralize('Kernel', len(ret))}")
|
||||
def jit_lower(linear:UOp, held_bufs:set[UOp]) -> list[ExecItem]:
|
||||
return [ei.lower() for ei in linear_to_schedule(memory_plan_rewrite(linear, held_bufs))]
|
||||
def create_graph_call(batch:list[UOp], input_buffers:set[Buffer]) -> UOp:
  """Wrap a batch of scheduled calls into one CUSTOM_FUNCTION call with arg "graph".

  The batch becomes a LINEAR UOp stored as the first source of the custom function. The remaining sources (and the
  arguments of the returned call) are the deduplicated BUFFER / BUFFER_VIEW sources of the batch whose underlying
  Buffer(s) intersect `input_buffers`. Metadata of every batch item is concatenated onto the returned call.
  """
  def touches_input(u:UOp) -> bool:
    # only BUFFER and BUFFER_VIEW sources are considered as graph inputs
    if u.op not in (Ops.BUFFER, Ops.BUFFER_VIEW): return False
    # a MultiBuffer counts as an input if any of its per-device buffers is a jit input
    real_bufs = u.buffer.bufs if isinstance(u.buffer, MultiBuffer) else [u.buffer]
    return not input_buffers.isdisjoint(real_bufs)

  graph_inputs = dedup(arg for si in batch for arg in si.src[1:] if touches_input(arg))
  body = UOp(Ops.LINEAR, src=tuple(batch))
  cf = UOp(Ops.CUSTOM_FUNCTION, dtypes.void, src=(body, *graph_inputs), arg="graph")
  batch_metadata = tuple(m for si in batch for m in si.arg.metadata)
  return cf.call(*graph_inputs, metadata=batch_metadata)
|
||||
|
||||
def graph_split_rewrite(linear:UOp, input_buffers:set[Buffer], max_batch_size:int=0) -> UOp:
  """Rewrite `linear` so that runs of graphable calls are replaced by single "graph" calls.

  Walks `linear.src` in order, accumulating consecutive items that the device's graph class accepts. A pending run is
  flushed when the next item cannot join it (not graphable, rejected for the pending devices, or the run is full).
  `max_batch_size == 0` means no size limit; otherwise the limit doubles after every run that is actually graphed.
  Items whose first source is a BUFFER_VIEW are skipped and do not appear in the result.
  """
  out: list[UOp] = []
  pending: list[UOp] = []
  pending_devs: list[Compiled] = []

  def flush_pending():
    nonlocal pending, pending_devs, max_batch_size
    if len(pending) > 1 or getenv("GRAPH_ONE_KERNEL"):
      out.append(create_graph_call(pending, input_buffers))
      max_batch_size *= 2
      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(pending)} kernels")
    else:
      # a single call is not worth graphing, emit it unchanged
      out.extend(pending)
    pending, pending_devs = [], []

  for si in linear.src:
    if si.src[0].op is Ops.BUFFER_VIEW: continue

    dev_names = si.device if isinstance(si.device, tuple) else (si.device,)
    devs = [Device[name] for name in dev_names]
    graph_t = None if devs[0].graph is None else graph_class(devs[0])

    # can this item be graphed at all, on its own devices?
    can_graph = graph_t is not None and graph_t.supports_exec_item(devs, si)

    # can it join the pending run? (graph_t re-checked only to narrow the Optional for type checkers)
    can_extend = False
    if can_graph and graph_t is not None:
      compatible = not pending_devs or graph_t.supports_exec_item(pending_devs, si)
      can_extend = compatible and (max_batch_size == 0 or len(pending) < max_batch_size)
    if pending and not can_extend: flush_pending()

    # graphable items start/extend the pending run, everything else is emitted as-is
    if can_graph:
      pending.append(si)
      pending_devs = dedup(pending_devs + devs)
    else:
      out.append(si)
      pending_devs = []
  if pending: flush_pending()
  return linear.replace(src=tuple(out))
|
||||
|
||||
def jit_cache_bufs(jit_cache:list[ExecItem]):
  """Yield every non-None buffer of every item in `jit_cache`, descending into the jit_cache of GraphRunner programs."""
  for item in jit_cache:
    yield from (buf for buf in item.bufs if buf is not None)
    # graphed items hold their own nested jit_cache; walk it after the item's direct buffers
    runner = item.prg
    if isinstance(runner, GraphRunner): yield from jit_cache_bufs(runner.jit_cache)
|
||||
|
||||
@track_rewrites(lambda linear,held_bufs,input_buffers=None,ret=(): f"JIT {pluralize('call', len(linear.src))}")
def jit_lower(linear:UOp, held_bufs:set[UOp], input_buffers:list[Buffer]|None=None) -> list[ExecItem]:
  """Lower a captured linear into ExecItems: memory-plan it, optionally split it into graph calls, then lower each item.

  Graph splitting only runs when JIT < 2. When VIZ is set, the linear is passed through an empty-rule graph_rewrite
  before planning and after splitting, under the names "View captured linear" and "View graphed linear".
  """
  if VIZ: graph_rewrite(linear, PatternMatcher([]), name="View captured linear")
  planned = memory_plan_rewrite(linear, held_bufs)
  if JIT < 2:
    jit_inputs = set(input_buffers) if input_buffers else set()
    planned = graph_split_rewrite(planned, jit_inputs, max_batch_size=JIT_BATCH_SIZE.value)
  if VIZ: graph_rewrite(planned, PatternMatcher([]), name="View graphed linear")
  lowered: list[ExecItem] = []
  for si in linear_to_schedule(planned): lowered.append(si.lower())
  return lowered
|
||||
|
||||
class GraphException(Exception): pass
|
||||
class JitError(Exception): pass
|
||||
@@ -38,55 +86,6 @@ def _check_no_non_tensor_return(ret):
|
||||
|
||||
def graph_class(dev):
  """Return the graph callable of `dev`, unwrapping a functools.partial to its underlying `func`."""
  graph = dev.graph
  if isinstance(graph, functools.partial): return graph.func
  return graph
|
||||
|
||||
def apply_graph_to_jit(jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
|
||||
orig_valid_positions: dict[int, set[int]]|None = None, max_batch_size=0) -> list[ExecItem]:
|
||||
# Split JIT cache into batches for faster graph execution.
|
||||
# This allows the accelerator to run some batches while subsequent graphs are still being updated.
|
||||
graphed_jit_cache: list[ExecItem] = []
|
||||
current_batch: list[ExecItem] = []
|
||||
current_batch_devs: list[Compiled] = []
|
||||
|
||||
def flush_batch():
|
||||
nonlocal current_batch, current_batch_devs, max_batch_size
|
||||
try:
|
||||
if len(current_batch_devs) == 0: raise GraphException("no device for graph")
|
||||
if len(current_batch) <= 1 and not getenv("GRAPH_ONE_KERNEL"): raise GraphException("only one kernel doesn't graph")
|
||||
graph_runner = current_batch_devs[0].graph(current_batch, input_buffers, var_vals, orig_valid_positions=orig_valid_positions)
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
|
||||
graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_buffers), prg=graph_runner))
|
||||
max_batch_size *= 2
|
||||
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
|
||||
except GraphException as e:
|
||||
graphed_jit_cache.extend(current_batch)
|
||||
if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_batch_devs[0]}: {e}")
|
||||
current_batch = []
|
||||
current_batch_devs = []
|
||||
|
||||
for ji in jit_cache:
|
||||
match ji.prg:
|
||||
case CompiledRunner(): ji_graph_dev = ji.prg.dev
|
||||
case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
|
||||
case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None)
|
||||
case ViewOp(): continue # ViewOps are just ignored
|
||||
case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed
|
||||
|
||||
# Check if this jit item can be graphed at all, so check if a new graph supports the current item.
|
||||
can_be_graphed = ji_graph_dev is not None and ji_graph_dev.graph is not None and graph_class(ji_graph_dev).supports_exec_item([ji_graph_dev], ji)
|
||||
|
||||
# Check if the current batch can be extended with this item.
|
||||
can_share_graph = can_be_graphed and len(current_batch_devs) > 0 and \
|
||||
graph_class(current_batch_devs[0]).supports_exec_item(dedup(current_batch_devs + [ji_graph_dev]), ji)
|
||||
can_extend_graph_batch = can_share_graph and (max_batch_size == 0 or len(current_batch) < max_batch_size)
|
||||
|
||||
# Flush the current batch if any, since it can't be extended or is full.
|
||||
if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
|
||||
(current_batch if can_be_graphed else graphed_jit_cache).append(ji)
|
||||
current_batch_devs = dedup(current_batch_devs + [ji_graph_dev]) if can_be_graphed else []
|
||||
|
||||
if len(current_batch) > 0: flush_batch()
|
||||
return graphed_jit_cache
|
||||
|
||||
def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
|
||||
orig_valid_positions: dict[int, set[int]]|None = None) -> dict[tuple[int, int], int]:
|
||||
input_replace: dict[tuple[int, int], int] = {}
|
||||
@@ -99,23 +98,30 @@ def get_input_replace(jit_cache: list[ExecItem], input_buffers:list[Buffer],
|
||||
return input_replace
|
||||
|
||||
class GraphRunner(Runner):
|
||||
def __init__(self, jit_cache: list[ExecItem], input_buffers: list[Buffer], var_vals: dict[str, int],
|
||||
orig_valid_positions: dict[int, set[int]]|None = None):
|
||||
self.jit_cache = jit_cache # NOTE: this is not used, but you have to keep these objects alive for the Graph
|
||||
self.input_replace:dict[tuple[int, int], int] = get_input_replace(jit_cache, input_buffers, orig_valid_positions)
|
||||
def __init__(self, linear:UOp|None, input_buffers:list[Buffer]|None,
|
||||
jit_cache:list[ExecItem]|None=None, input_replace:dict[tuple[int,int],int]|None=None):
|
||||
# TODO: captured jit as linear?
|
||||
if linear is not None:
|
||||
jit_cache = [ei.lower() for ei in linear_to_schedule(linear.src[0])]
|
||||
for b in jit_cache_bufs(jit_cache): b.ensure_allocated()
|
||||
input_replace = get_input_replace(jit_cache, input_buffers) if input_buffers else {}
|
||||
self.jit_cache, self.input_replace = unwrap(jit_cache), input_replace or {}
|
||||
|
||||
self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
|
||||
self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
|
||||
self.launch_dims_base:dict[int, tuple[tuple[int, ...], tuple[int, ...]]] = {}
|
||||
|
||||
def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)
|
||||
|
||||
self.vars = sorted(var_vals.keys())
|
||||
self.symbolic_dims = dedup([tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.local_size) and is_sym_dim(d)] +
|
||||
[tuple(d) for ji in jit_cache if isinstance(ji.prg, CompiledRunner) and (d:=ji.prg.p.global_size) and is_sym_dim(d)])
|
||||
crs = [(ji, ji.prg) for ji in self.jit_cache if isinstance(ji.prg, CompiledRunner)]
|
||||
self.vars = sorted({v.expr for ji,p in crs for v in p.p.vars if v.expr not in ji.fixedvars | p.p.runtimevars})
|
||||
self.symbolic_dims = dedup([tuple(d) for _,p in crs if (d:=p.p.local_size) and is_sym_dim(d)] +
|
||||
[tuple(d) for _,p in crs if (d:=p.p.global_size) and is_sym_dim(d)])
|
||||
|
||||
def find_symbolic_dim(dim): return self.symbolic_dims.index(tuple(dim)) if dim is not None and tuple(dim) in self.symbolic_dims else None
|
||||
|
||||
estimates = Estimates()
|
||||
for j,ji in enumerate(jit_cache):
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
assert ji.prg is not None
|
||||
estimates += ji.prg.estimates
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
@@ -132,8 +138,10 @@ class GraphRunner(Runner):
|
||||
self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
|
||||
|
||||
assert jit_cache[0].prg is not None
|
||||
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
|
||||
assert self.jit_cache[0].prg is not None
|
||||
super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), self.jit_cache[0].prg.device.split(":")[0], estimates.simplify())
|
||||
|
||||
def __reduce__(self): return self.__class__, (None, None, self.jit_cache, self.input_replace)
|
||||
|
||||
def updated_vars(self, var_vals: dict[str, int]):
|
||||
vals = [var_vals[v] for v in self.vars]
|
||||
@@ -165,18 +173,25 @@ class GraphRunner(Runner):
|
||||
return list({id(x):x for x in wait_nodes}.values())
|
||||
|
||||
@staticmethod
|
||||
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool: return isinstance(ei.prg, CompiledRunner) and len(dedup(devs)) == 1
|
||||
def _all_devs(batch_devs:list[Compiled], new_call:UOp) -> list[Compiled]:
|
||||
return dedup(batch_devs + [Device[x] for b in new_call.src[1:] if b.op is not Ops.BIND
|
||||
for x in (b.device if isinstance(b.device, tuple) else (b.device,))])
|
||||
|
||||
@staticmethod
|
||||
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM) and len(GraphRunner._all_devs(batch_devs, new_call)) == 1
|
||||
|
||||
# a marker for your graph supporting multiple devices of the same type
|
||||
class MultiGraphRunner(GraphRunner):
|
||||
@staticmethod
|
||||
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
|
||||
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
# Devices must be the same type
|
||||
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) and len(dedup([type(Device[b.device]) for b in ei.bufs if b]+[type(d) for d in devs]))==1
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM, Ops.COPY) and len(dedup([type(d) for d in GraphRunner._all_devs(batch_devs, new_call)])) == 1
|
||||
|
||||
def get_out_buffers_for_ei(ei:ExecItem) -> list[Buffer]:
  """Return the buffers `ei` writes, based on the type of its program; empty list for any other program type."""
  prg = ei.prg
  if isinstance(prg, CompiledRunner):
    # outputs that are also listed as inputs are excluded
    return [cast(Buffer, ei.bufs[idx]) for idx in prg.p.outs if idx not in prg.p.ins]
  if isinstance(prg, (BufferCopy, BufferXfer, EncDec)):
    # for these program types the first buffer is taken as the output
    return [cast(Buffer, ei.bufs[0])]
  if isinstance(prg, GraphRunner):
    # a graph's outputs are the deduplicated outputs of the items it contains
    nested: list[Buffer] = []
    for inner in prg.jit_cache: nested.extend(get_out_buffers_for_ei(inner))
    return dedup(nested)
  return []
|
||||
|
||||
def update_depends(depends:set[Buffer|None], jit_cache:list[ExecItem]):
|
||||
@@ -201,6 +216,7 @@ class CapturedJit(Generic[ReturnType]):
|
||||
self._jit_cache: list[ExecItem] = self.jit_cache
|
||||
self._input_replace: dict[tuple[int, int], int] = self.input_replace
|
||||
self._first_run = True
|
||||
self._needs_rebuild = False
|
||||
# precompute read-after-write hazard detection
|
||||
self._output_to_writer = {b: j for j, ei in enumerate(self.jit_cache) for b in get_out_buffers_for_ei(ei)}
|
||||
self._input_to_max_reader: dict[int, int] = {}
|
||||
@@ -217,12 +233,13 @@ class CapturedJit(Generic[ReturnType]):
|
||||
depends: set[Buffer|None] = set([None])
|
||||
update_depends(depends, self.jit_cache)
|
||||
arenas = {b._base for b in depends if b is not None and b._base is not None}
|
||||
to_free = {b for b in depends if b is not None} | {b for ei in self.jit_cache for b in ei.bufs if b is not None and b._base in arenas}
|
||||
to_free = {b for b in depends if b is not None} | {b for b in jit_cache_bufs(self.jit_cache) if b._base in arenas}
|
||||
for b in to_free:
|
||||
if hasattr(b, '_buf'): b.deallocate()
|
||||
for a in arenas:
|
||||
if a.allocated_views == 0 and a.is_allocated(): a.deallocate()
|
||||
self.__post_init__()
|
||||
self._needs_rebuild = True
|
||||
|
||||
# jit exec
|
||||
def __call__(self, input_buffers:list[Buffer], var_vals:dict[str, int]) -> ReturnType:
|
||||
@@ -237,23 +254,13 @@ class CapturedJit(Generic[ReturnType]):
|
||||
|
||||
for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
|
||||
|
||||
# Condense the items into a graph executor.
|
||||
# allocate intermediates if freed on first run
|
||||
if self._first_run:
|
||||
# allocate intermediates if freed
|
||||
for ji in self.jit_cache:
|
||||
for b in ji.bufs:
|
||||
if b is not None: b.ensure_allocated()
|
||||
# create graph if needed
|
||||
if JIT < 2:
|
||||
# build a map from ExecItem object to the buffer positions that are valid inputs (from original input_replace)
|
||||
orig_valid_positions: dict[int, set[int]] = {} # id(ExecItem) -> set of valid buffer indices
|
||||
for (j, i) in self.input_replace: orig_valid_positions.setdefault(id(self.jit_cache[j]), set()).add(i)
|
||||
self._jit_cache = apply_graph_to_jit(self.jit_cache, input_buffers, var_vals, orig_valid_positions, max_batch_size=JIT_BATCH_SIZE.value)
|
||||
# recompute input_replace: GraphRunner items have all positions valid, non-GraphRunner items use orig_valid_positions
|
||||
valid_positions = {id(ji): set(range(len(ji.bufs))) if isinstance(ji.prg, GraphRunner) else orig_valid_positions.get(id(ji), set())
|
||||
for ji in self._jit_cache}
|
||||
self._input_replace = get_input_replace(self._jit_cache, input_buffers, valid_positions)
|
||||
self._first_run = False
|
||||
for b in jit_cache_bufs(self.jit_cache): b.ensure_allocated()
|
||||
if self._needs_rebuild:
|
||||
for ei in self.jit_cache:
|
||||
if isinstance(ei.prg, GraphRunner): ei.prg = type(ei.prg)(None, None, ei.prg.jit_cache, ei.prg.input_replace)
|
||||
self._first_run = self._needs_rebuild = False
|
||||
|
||||
if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
|
||||
for ei in self._jit_cache: ei.run(var_vals, jit=True)
|
||||
@@ -341,7 +348,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
|
||||
held_bufs = set(buffers) | {t.uop.buf_uop for t in get_parameters(ret) if t.uop.buf_uop.op is Ops.BUFFER}
|
||||
with Context(BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
jit_cache = jit_lower(big_linear, held_bufs)
|
||||
jit_cache = jit_lower(big_linear, held_bufs, input_buffers)
|
||||
|
||||
# track inputs that are views of buffers
|
||||
# TODO: eventually expected_buffers should live in ExecItem
|
||||
|
||||
@@ -131,6 +131,7 @@ si_lowerer = PatternMatcher([
|
||||
if hasattr(alc:=Device[ctx[0].device].allocator, '_transfer') and alc.supports_transfer and all_same([x.device.split(":")[0] for x in ctx]) \
|
||||
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="encdec", name="cf"), lambda ctx,cf: EncDec(cf, ctx[0].nbytes, ctx[0].device)),
|
||||
(UPat(Ops.CUSTOM_FUNCTION, arg="graph", name="cf"), lambda ctx,cf: Device[cf.device if isinstance(cf.device,str) else cf.device[0]].graph(cf, ctx))
|
||||
])
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -74,7 +74,9 @@ def linear_to_schedule(linear:UOp) -> list[ExecItem]:
|
||||
buffers[buf_uops[0]] = base.view(buf_uops[0].arg, ast.dtype, ast.arg[1]*base.dtype.itemsize)
|
||||
ubufs = [b.buffer for b in buf_uops if b.op is not Ops.BIND]
|
||||
metadata = si.arg.metadata
|
||||
if any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
if ast.op is Ops.CUSTOM_FUNCTION and ast.arg == "graph":
|
||||
schedule.append(ExecItem(ast, flatten([b.bufs if isinstance(b, MultiBuffer) else [b] for b in ubufs]), metadata))
|
||||
elif any(isinstance(x, MultiBuffer) for x in ubufs):
|
||||
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
|
||||
dnums = [x for x in ast.variables() if x.expr == '_device_num']
|
||||
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import collections, time
|
||||
from typing import Any, cast
|
||||
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, dedup, suppress_finalizing, TracingKey
|
||||
from tinygrad.helpers import round_up, PROFILE, ALL2ALL, merge_dicts, getenv, suppress_finalizing, TracingKey
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator, MMIOInterface
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, Device, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Variable
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner, BufferCopy
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
from tinygrad.uop.ops import UOp, Ops, Variable
|
||||
from tinygrad.engine.realize import BufferXfer, CompiledRunner, BufferCopy
|
||||
from tinygrad.engine.jit import GraphRunner, MultiGraphRunner
|
||||
|
||||
class HCQGraph(MultiGraphRunner):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -239,9 +239,9 @@ class HCQGraph(MultiGraphRunner):
|
||||
for fdev, buf in self.kernargs_bufs.items(): fdev.allocator._free(buf, BufferSpec(cpu_access=True))
|
||||
|
||||
@staticmethod
|
||||
def supports_exec_item(devs:list[Compiled], ei:ExecItem) -> bool:
|
||||
def supports_exec_item(batch_devs:list[Compiled], new_call:UOp) -> bool:
|
||||
# Check if all devices are HCQ
|
||||
all_devs = cast(list[HCQCompiled], dedup(devs + [Device[b.device] for b in ei.bufs if b]))
|
||||
all_devs = cast(list[HCQCompiled], GraphRunner._all_devs(batch_devs, new_call))
|
||||
if not all(issubclass(type(d), HCQCompiled) for d in all_devs): return False
|
||||
|
||||
# If all of devices are mapped into CPU address space, can use CPU inside the peer group.
|
||||
@@ -250,6 +250,8 @@ class HCQGraph(MultiGraphRunner):
|
||||
# Check if all devices are within the same peer group. If CPU is supported, don't count it as a separate peer group.
|
||||
if len(set(d.peer_group for d in all_devs if not (cpu_support and d._is_cpu()))) > 1: return False
|
||||
|
||||
# MOCKGPU is not supported, since it can't execute commands in parallel
|
||||
copy = (isinstance(ei.prg, BufferCopy) and cast(HCQCompiled, devs[0]).hw_copy_queue_t is not None) and not getenv("MOCKGPU")
|
||||
return isinstance(ei.prg, (CompiledRunner, BufferXfer)) or copy
|
||||
if new_call.src[0].op is Ops.COPY:
|
||||
# MOCKGPU is not supported, since it can't execute commands in parallel
|
||||
is_xfer = len(set(type(d) for d in all_devs)) == 1 and hasattr(alc:=all_devs[0].allocator, '_transfer') and alc.supports_transfer
|
||||
return is_xfer or (all_devs[0].hw_copy_queue_t is not None and not getenv("MOCKGPU"))
|
||||
return new_call.src[0].op in (Ops.SINK, Ops.PROGRAM)
|
||||
|
||||
@@ -3,9 +3,10 @@ import ctypes, re, decimal
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, getenv, merge_dicts, PROFILE
|
||||
from tinygrad.device import Buffer, ProfileGraphEntry, ProfileGraphEvent
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.runtime.ops_metal import wait_check, to_ns_str, MetalBuffer
|
||||
from tinygrad.runtime.ops_metal import wait_check, to_ns_str
|
||||
from tinygrad.runtime.autogen import metal
|
||||
from tinygrad.runtime.support import objc
|
||||
|
||||
@@ -109,7 +110,7 @@ class MetalGraph(GraphRunner):
|
||||
self.collect_timestamps()
|
||||
|
||||
@staticmethod
|
||||
def supports_exec_item(devs, ei:ExecItem) -> bool:
|
||||
def supports_exec_item(batch_devs, new_call:UOp) -> bool:
|
||||
# Metal ICB replay encodes offsets as uint32; reject if any Metal buffer offset exceeds 32-bit range.
|
||||
if any(b is not None and isinstance(b._buf, MetalBuffer) and b._buf.offset > 0xFFFFFFFF for b in ei.bufs): return False
|
||||
return GraphRunner.supports_exec_item(devs, ei)
|
||||
if any(b.op is Ops.BUFFER_VIEW and b.arg[1] * b.dtype.itemsize > 0xFFFFFFFF for b in new_call.src[1:]): return False
|
||||
return GraphRunner.supports_exec_item(batch_devs, new_call)
|
||||
|
||||
Reference in New Issue
Block a user