diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 49a4c093ef..6a0d5cc513 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any -import functools, itertools, collections +import functools, collections from tinygrad.tensor import Tensor from tinygrad.engine.lazy import LazyBuffer from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition @@ -172,17 +172,15 @@ class CapturedJit(Generic[ReturnType]): return self.ret def _prepare_jit_inputs(args, kwargs): - input_tensors: List[Tuple[Union[int, str], Tensor]] = \ - [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor] - if input_tensors: Tensor.realize(*[t for _,t in input_tensors]) - names: List[Union[int, str]] = [name for name,_ in input_tensors] - lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors]) - st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs] + input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] + names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] + if tensors: Tensor.realize(*tensors) + lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors]) input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None] assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT" - var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \ - [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, UOp))]) - st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device] + st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs] + var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))]) + st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device] return input_buffers, var_vals, names, st_vars_dtype_device class TinyJit(Generic[ReturnType]):