From 8656eebb42d0773d1424085967ccffa0a0f30450 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 23 Nov 2023 00:13:18 -0800 Subject: [PATCH] jit doesn't use named tensors (#2393) * jit doesn't use named tensors * move to compile2 * remove broken single root junk * explicit float32 * skip slow test --- .github/workflows/test.yml | 9 +- openpilot/compile.py | 150 ---------------------------------- openpilot/compile2.py | 10 ++- openpilot/go.sh | 2 +- test/test_linearizer.py | 1 + tinygrad/features/image.py | 3 - tinygrad/graph.py | 4 +- tinygrad/jit.py | 14 ++-- tinygrad/lazy.py | 1 - tinygrad/ops.py | 16 ++-- tinygrad/runtime/ops_metal.py | 12 +-- 11 files changed, 36 insertions(+), 186 deletions(-) delete mode 100644 openpilot/compile.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5d0c2baee3..7922a02f87 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -178,17 +178,14 @@ jobs: - if: ${{ matrix.task == 'openpilot' }} name: Test openpilot model compile and size run: | - DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py + DEBUG=2 ALLOWED_KERNEL_COUNT=207 VALIDTEST=1 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'openpilot' }} name: Test openpilot model correctness (float32) - run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py - - if: ${{ matrix.task == 'openpilot' }} - name: Test openpilot model correctness (float32, new compiler) - run: DEBUGCL=1 FLOAT16=0 python3 openpilot/compile2.py + run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py - if: ${{ matrix.task == 'openpilot' }} name: Test openpilot alt model correctness (float32) - run: DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx + run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx - if: ${{ matrix.task == 'openpilot' }} name: Test tensor core ops run: GPU=1 TC=2 python -m pytest -n=auto test/test_ops.py diff --git a/openpilot/compile.py b/openpilot/compile.py deleted file mode 100644 index 8ed27fcd95..0000000000 --- a/openpilot/compile.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -import os, time, io, pathlib, sys, traceback, re -sys.path.insert(0, str(pathlib.Path(__file__).parents[1])) - -if os.getenv("OPT", None) is None: - os.environ['OPT'] = '99' -if os.getenv("GPU", None) is None: - os.environ['GPU'] = '1' -if os.getenv("IMAGE", None) is None: - os.environ['IMAGE'] = '2' - -from tinygrad.helpers import getenv, dtypes -ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0) -DEBUGCL = getenv("DEBUGCL", 0) - -import onnx -import numpy as np - -import tinygrad.graph as graph -from tinygrad.helpers import GlobalCounters -from tinygrad.jit import TinyJit, CacheCollector - -import pyopencl as cl -from tinygrad.runtime.ops_gpu import CL -from extra.utils import fetch -from extra.onnx import get_run_onnx -from tinygrad.tensor import Tensor - -OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" - -np.random.seed(1337) -def get_random_input_tensors(input_shapes): - # this 16 is a random scale factor - inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()} - np_inputs = {k:v.realize().numpy() for k,v in inputs.items()} - return inputs, np_inputs - -@TinyJit -def model_exec(run_onnx, using_graph, **inputs): - ret = next(iter(run_onnx(inputs).values())).cast(dtypes.float32) - GlobalCounters.reset() - CacheCollector.start() # don't cache pre-realize - if using_graph: graph.GRAPH = True - print("realizing") - return ret.realize() - -def compile(dat, output_fn): - Tensor.manual_seed(1337) - Tensor.no_grad = True - using_graph = graph.GRAPH - if getenv("GRAPH") < 3: graph.GRAPH = False - - onnx_model = onnx.load(io.BytesIO(dat)) - run_onnx = get_run_onnx(onnx_model) - input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - - inputs, np_inputs = get_random_input_tensors(input_shapes) - # run twice to trigger the JIT - for i in range(2): tinygrad_out = model_exec(run_onnx, i == 1 and using_graph, **inputs) - graph.GRAPH = False - print("kernel count:", len(model_exec.jit_cache)) - assert len(model_exec.jit_cache) <= ALLOWED_KERNEL_COUNT or ALLOWED_KERNEL_COUNT == 0, "too many kernels!" - - # pull out inputs and put them in the jit cache - input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()} - for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx] - - # transform to CL.CACHE - used_ops = 0 - cl_cache = [] - for ji in model_exec.jit_cache: - prg = ji.prg - # pass these to thneed - setattr(prg.clprg, 'op_estimate', prg.op_estimate) - setattr(prg.clprg, 'prg', prg.prg) - - if getenv("VALIDTEST") == 1: - src = re.search(r"=.*\?.*?read_image", prg.prg) - if src is not None: raise Exception("Openpilot has valid checks!") - - global_size = prg.global_size + [1]*(3-len(prg.global_size)) - local_size = prg.local_size + [1]*(3-len(prg.local_size)) - cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x._buf for x in ji.rawbufs]])) - used_ops += prg.op_estimate - - from extra.thneed import Thneed - t = Thneed(cl_cache, {k:v._buf for k,v in input_rawbuffers.items()}) - - # save thneed (before run) - t.save(output_fn) - - print(f"buffers to save: {len(t.buffers_to_save)}, inputs: {list(t.inputs.keys())}, outputs: {t.outputs}") - runtime = t.run() - print(f"network using {used_ops/1e9:.2f} GOPS with runtime {runtime*1e3:.2f} ms that's {used_ops/runtime*1e-9:.2f} GFLOPS") - - # confirm thneed found the right output - thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue[0], thneed_out, t.outputs[0], is_blocking=True) - np.testing.assert_allclose(thneed_out, tinygrad_out.numpy()) - - # testing is float32 only (fix this) - FLOAT16 = getenv("FLOAT16", 0) - if FLOAT16 == 0: - try: - from test.models.test_onnx import run_onnx_torch - torch_out = run_onnx_torch(onnx_model, np_inputs).numpy() - print(thneed_out, torch_out, "mse", np.sum((thneed_out-torch_out)**2), "max err", np.max(np.abs((thneed_out-torch_out)))) - np.testing.assert_allclose(torch_out, thneed_out, atol=1e-4, rtol=1e-2) - - # test loading/run thneed - _, new_np_inputs = get_random_input_tensors(input_shapes) - new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy() - - # try old thneed with a different input - for k,v in t.inputs.items(): - cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True) - - t.run() - old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue[0], old_thneed_out, t.outputs[0], is_blocking=True) - - # compare thneed (rerun) with torch - np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2) - - # load thneed and try that - _, new_np_inputs = get_random_input_tensors(input_shapes) - new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy() - nt = Thneed() - nt.load(output_fn) - - # inputs - for k,v in nt.inputs.items(): - cl.enqueue_copy(CL.cl_queue[0], v, new_np_inputs[k], is_blocking=True) - - nt.run() - new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape) - cl.enqueue_copy(CL.cl_queue[0], new_thneed_out, nt.outputs[0], is_blocking=True) - - # compare torch to thneed - np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2) - print("thneed self-test passed!") - except ModuleNotFoundError as e: - print(f"TEST NOT HAPPENING {e}") - - -# UNSAFE_FLOAT4=1 DEBUGCL=1 FLOAT16=1 python3 openpilot/compile.py -# 22.59 ms -if __name__ == "__main__": - dat = fetch(OPENPILOT_MODEL if len(sys.argv) == 1 else sys.argv[1]) - compile(dat, sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed") diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 1ee84f91eb..94cf40da78 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 -import os, sys, io, pathlib +import os, sys, io, pathlib, re sys.path.insert(0, str(pathlib.Path(__file__).parents[1])) if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" if "OPT" not in os.environ: os.environ["OPT"] = "99" -os.environ["PREREALIZE"] = "0" OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" @@ -55,6 +54,9 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: def schedule_to_thneed(schedule, output_fn): from extra.thneed import Thneed + print("kernel count:", len(schedule)) + assert len(schedule) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!" + # transform to CL.CACHE used_ops = 0 cl_cache = [] @@ -66,6 +68,10 @@ def schedule_to_thneed(schedule, output_fn): setattr(prg.clprg, 'op_estimate', prg.op_estimate) setattr(prg.clprg, 'prg', prg.prg) + if getenv("VALIDTEST") == 1: + src = re.search(r"=.*\?.*?read_image", prg.prg) + if src is not None: raise Exception("Openpilot has valid checks!") + global_size = prg.global_size + [1]*(3-len(prg.global_size)) local_size = prg.local_size + [1]*(3-len(prg.local_size)) cl_cache.append((prg.clprg, [[int(g*l) for g,l in zip(global_size, local_size)], local_size, *[x.realized._buf for x in args]])) diff --git a/openpilot/go.sh b/openpilot/go.sh index f27fc7f8eb..d99c706e77 100755 --- a/openpilot/go.sh +++ b/openpilot/go.sh @@ -1,2 +1,2 @@ #!/bin/bash -NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile.py +NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 openpilot/compile2.py diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 9701e2ddf1..0984ca0263 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -496,6 +496,7 @@ class TestLinearizerOpts(unittest.TestCase): def test_padto_matmul(self): if not isinstance(Device[Device.DEFAULT], Compiled): self.skipTest("Only Compiled uses linearizer") + if Device.DEFAULT == "CUDA": self.skipTest("super slow on CUDA/triton") N = 17 * 17 Tensor.manual_seed(289) a = Tensor.rand(N, N) diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py index ee8e3517e6..1e843494b8 100644 --- a/tinygrad/features/image.py +++ b/tinygrad/features/image.py @@ -3,8 +3,6 @@ from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG, fla # *** image Tensor function replacements *** -from tinygrad.lazy import get_single_root - def image_dot(self, w): # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) n1, n2 = len(self.shape), len(w.shape) @@ -60,7 +58,6 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin # contiguous creates the image, and early realize static weights (TODO: test for the static weight) if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) x, w = x.contiguous(), w.contiguous() - if getenv("PREREALIZE", 1) and get_single_root(w.lazydata).realized: w.realize() # expand out rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1 diff --git a/tinygrad/graph.py b/tinygrad/graph.py index e8a5976b97..f0bbc4f224 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -7,7 +7,7 @@ from collections import defaultdict from typing import Dict, List from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters, getenv, dedup -from tinygrad.codegen.linearizer import UOps +from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import NumNode @@ -107,7 +107,7 @@ def _tree(lazydata, prefix=""): def print_tree(lazydata:LazyOp): print("\n".join([f"{str(i).rjust(3)} {s}" for i,s in enumerate(_tree(lazydata))])) -def graph_uops(uops): +def graph_uops(uops:List[UOp]): colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.SPECIAL: "#c0c0ff", UOps.CONST: "#e0e0e0", UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.LOOP: "#c8a0e0", UOps.PHI: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0"} diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 5ed94ed41b..93ea6400ad 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -20,12 +20,12 @@ class TinyJit(Generic[ReturnType]): self.cnt: int = 0 self.ret: Optional[ReturnType] = None self.expected_vals: Optional[Tuple[Variable, ...]] = None - self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None + self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None @property def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else [] @property - def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {} + def input_replace(self) -> Dict[Tuple[int, int], int]: return self.jit_fxn.input_replace if self.jit_fxn else {} # add support for instance methods def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) @@ -33,11 +33,11 @@ class TinyJit(Generic[ReturnType]): def __call__(self, *args, **kwargs) -> ReturnType: # all inputs are realized input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor} - expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()]) + expected_name_sts_dtype = tuple([(k, v.lazydata.st.unbind(), v.dtype) for k,v in input_tensors.items()]) # get rawbuffers - input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()} - assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT" + input_rawbuffers: List[RawBuffer] = [cast(RawBuffer, v.lazydata.realized) for v in input_tensors.values()] + assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT" # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]) @@ -45,11 +45,11 @@ class TinyJit(Generic[ReturnType]): if self.cnt >= 2: assert self.expected_vals == expected_vals, "mismatch of var_vals" - assert self.expected_sts_dtype == expected_sts_dtype, f"mismatch of sts, expected {self.expected_sts_dtype} got {expected_sts_dtype}" + assert self.expected_name_sts_dtype == expected_name_sts_dtype, f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}" assert self.jit_fxn, "didn't get jitted?" self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2) elif self.cnt == 1: - self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype + self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype CacheCollector.start(var_vals) self.ret = self.fxn(*args, **kwargs) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 76d31b9dd7..9a2163bac7 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -76,7 +76,6 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]: # **** lazy operations **** -def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(root.op.src[0]) if getattr(root, 'op', None) and len(root.op.src) == 1 and isinstance(root.op.src[0], LazyBuffer) else root def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f9a6bd4ed4..ece91d2ec0 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -37,7 +37,7 @@ class MemBuffer: @dataclass(frozen=True) class ConstBuffer: - val: Any + val: Union[int, float] dtype: DType st: ShapeTracker @@ -152,22 +152,22 @@ class JitItem: rawbufs: List[Optional[RawBuffer]] class BatchExecutor: - def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]): + def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]): self.jit_cache: List[JitItem] = jit_cache - self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {} + self.input_replace: Dict[Tuple[int, int], int] = {} self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0) for j,ji in enumerate(jit_cache): if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored self.op_estimate += ji.prg.op_estimate self.mem_estimate += ji.prg.mem_estimate for i,a in enumerate(ji.rawbufs): - if a in [v for v in input_rawbuffers.values()]: - self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0] - assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found" + if a in input_rawbuffers: + self.input_replace[(j,i)] = input_rawbuffers.index(a) + assert len(set(self.input_replace.values())) == len(input_rawbuffers), "some input tensors not found" self.clear_jit_inputs() - def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False): - for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name] + def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False): + for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx] for ji in self.jit_cache: ji.prg(ji.rawbufs, var_vals, jit=True) self.clear_jit_inputs() diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index db7ebf4b15..1043744b81 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,7 +1,7 @@ # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch import os, subprocess, pathlib, ctypes, tempfile import Metal, libdispatch -from typing import List, Any, Tuple, Dict, Union, Set, cast +from typing import List, Any, Tuple, Dict, Set, cast from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup from tinygrad.ops import Compiled, BatchExecutor, JitItem, CompiledASTRunner, update_stats @@ -85,7 +85,7 @@ class MetalProgram: METAL.mtl_buffers_in_flight.append(command_buffer) class MetalBatchExecutor(BatchExecutor): - def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]): + def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int]): super().__init__(jit_cache, input_rawbuffers, var_vals) # create metal batch exec @@ -127,12 +127,12 @@ class MetalBatchExecutor(BatchExecutor): self.command_buffer: Any = None self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to - def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False): + def __call__(self, input_rawbuffers: List[RawBuffer], var_vals: Dict[Variable, int], wait=False): # NOTE: you at least can't update the ints if this is running if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted() - all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()] - for (j,i),input_name in self.input_replace.items(): - self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i) + all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers] + for (j,i),input_idx in self.input_replace.items(): + self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i) for j in self.input_has_variable_dims: global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals) self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))