clean up _prepare_jit_inputs [pr] (#7457)

removed an unnecessary cast and reordered a bit
This commit is contained in:
chenyu
2024-10-31 20:41:02 -04:00
committed by GitHub
parent a21434504b
commit 036409266d

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
import functools, itertools, collections
import functools, collections
from tinygrad.tensor import Tensor
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
@@ -172,17 +172,15 @@ class CapturedJit(Generic[ReturnType]):
return self.ret
def _prepare_jit_inputs(args, kwargs):
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
[(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
names: List[Union[int, str]] = [name for name,_ in input_tensors]
lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
if tensors: Tensor.realize(*tensors)
lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, UOp))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
return input_buffers, var_vals, names, st_vars_dtype_device
class TinyJit(Generic[ReturnType]):