mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-29 00:38:10 -05:00
clean up _prepare_jit_inputs [pr] (#7457)
removed an unnecessary cast and reordered a bit
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
|
||||
import functools, itertools, collections
|
||||
import functools, collections
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition
|
||||
@@ -172,17 +172,15 @@ class CapturedJit(Generic[ReturnType]):
|
||||
return self.ret
|
||||
|
||||
def _prepare_jit_inputs(args, kwargs):
|
||||
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
|
||||
[(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
|
||||
if input_tensors: Tensor.realize(*[t for _,t in input_tensors])
|
||||
names: List[Union[int, str]] = [name for name,_ in input_tensors]
|
||||
lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for _,t in input_tensors])
|
||||
st_varvals_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
|
||||
input_tensors: List[Tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
|
||||
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
|
||||
if tensors: Tensor.realize(*tensors)
|
||||
lbs: List[LazyBuffer] = flatten([t.lazydata.lbs for t in tensors])
|
||||
input_buffers: List[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
|
||||
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
|
||||
var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
|
||||
[dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, UOp))])
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
|
||||
st_varval_dtype_device = [(*lb.st.unbind(), lb.dtype, lb.device) for lb in lbs]
|
||||
var_vals = merge_dicts([x[1] for x in st_varval_dtype_device] + [dict(v.unbind() for v in (args + tuple(kwargs.values())) if isinstance(v, UOp))])
|
||||
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varval_dtype_device]
|
||||
return input_buffers, var_vals, names, st_vars_dtype_device
|
||||
|
||||
class TinyJit(Generic[ReturnType]):
|
||||
|
||||
Reference in New Issue
Block a user