diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index bff3c51d33..3cbf7eaabe 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -55,7 +55,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
       expected = f(a, b).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 2
+    assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)

   def test_attention(self):
     def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ class TestSymbolicJit(unittest.TestCase):
       symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
       expected = f(q, k, v).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-    assert len(jf.jit_cache) == 6
+    assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)

   def test_cat_dim0(self):
     def f(a, b): return a.cat(b, dim=0).realize()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index f330f9a283..aa629237eb 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -169,7 +169,7 @@ class Linearizer(Kernel):
       if isinstance(buf, MemBuffer):
         self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
     # add var vals
-    for var in sorted(vars_from_ast(self.ast)):
+    for var in vars_from_ast(self.ast):
       assert var.expr is not None
       self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
     # define local buffers
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 9a2163bac7..ad96faee4d 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -79,7 +79,8 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
 def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
 def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

-def vars_from_ast(ast:LazyOp) -> Set[Variable]: return set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set())
+# NOTE: this is the canonical order
+def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))

 lazycache: WeakValueDictionary = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
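The core of the change so far: `vars_from_ast` now returns a canonically ordered `List[Variable]` (sorted by the variable's `expr` string) instead of an unordered `Set[Variable]`, so every consumer sees the variables in the same order and the ad-hoc `sorted(...)` calls at the use sites can go. A minimal standalone sketch of the idea, using a stand-in `Var` class rather than tinygrad's `Variable`:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:  # stand-in for tinygrad.shape.symbolic.Variable
  expr: str
  min: int
  max: int

def canonical_vars(per_buffer_vars):
  # union the per-buffer variable sets, then sort by name: sets are
  # unordered, so without the sort two traversals of the same AST could
  # hand back the same variables in different orders
  return sorted(set().union(*per_buffer_vars, set()), key=lambda v: str(v.expr))

i, j = Var("i", 1, 10), Var("j", 1, 10)
assert canonical_vars([{j}, {i, j}]) == [i, j]  # deterministic every run
```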
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 9c9c846e66..bbd66a9517 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import importlib, inspect, functools, pathlib, time, re
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set
+from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
 from tinygrad.runtime.lib import RawBuffer
 from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -184,10 +184,8 @@ class Interpreted:

   def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
     if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
-    rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
-    if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer)
-    self.method_cache[ast].exec(rawbufs, var_vals)
-    output.realized = rawbufs[0]
+    output.realized = output.output_buffer if output.output_buffer is not None else self.buffer.__new__(self.buffer)
+    self.method_cache[ast].exec([output.realized] + [x.realized for x in inputs], var_vals)

 def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
   if DEBUG >= 3:
@@ -236,7 +234,7 @@ class CompiledASTRunner(JITRunner):
     if DEBUG >= 4: print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
       name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-    self.vars: Set[Variable] = set()
+    self.vars: List[Variable] = []
     if ast:
       info = get_lazyop_info(ast)
       self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
@@ -255,8 +253,6 @@ class CompiledASTRunner(JITRunner):
     return global_size, local_size

   def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-    # filter the var_vals
-    var_vals = {k:var_vals[k] for k in sorted(self.vars)}
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size):  # type: ignore[arg-type]
       # TODO: this is copied from get_program
@@ -266,7 +262,7 @@ class CompiledASTRunner(JITRunner):
     lra = self.runtime_args.copy()
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-    et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2)
+    et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
     update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
     return et
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 350996bc7c..c9000f42b0 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -114,7 +114,7 @@ class MetalGraph:
           icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
           if i == 0: write_resources.append(b._buf)
           else: read_resources.append(b._buf)
-      var_vals_keys = sorted(var_vals.keys())
+      var_vals_keys = list(var_vals.keys())
       for i,v in enumerate(prg.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
       global_size, local_size = prg.launch_dims(var_vals)
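On the dispatch side, `CompiledASTRunner.__call__` no longer re-sorts `var_vals`; it replays `self.vars`, the ordered list captured from `vars_from_ast(ast)` at build time, and passes the values positionally. The contract is that codegen order and call order agree. A hedged sketch of that contract (names below are illustrative, not tinygrad's API):

```python
def compile_kernel(ordered_vars):
  # captured once at codegen time, in canonical order
  vars_in_order = list(ordered_vars)
  def call(rawbufs, var_vals):
    # positional binding: var_vals may hold extra keys or arrive in any
    # order, but indexing by vars_in_order keeps each value bound to the
    # argument slot the generated kernel expects
    return rawbufs + [var_vals[v] for v in vars_in_order]
  return call

call = compile_kernel(["i", "j"])
assert call(["buf0"], {"j": 4, "i": 3}) == ["buf0", 3, 4]
```

The same reasoning covers the `MetalGraph` hunk: once the canonical order is established upstream, re-sorting with `sorted(var_vals.keys())` is redundant, and `list(var_vals.keys())` suffices for computing each variable's offset into `int_buf`.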