diff --git a/test/test_pickle.py b/test/test_pickle.py
index 4e6a972fe9..75ae39b2e1 100644
--- a/test/test_pickle.py
+++ b/test/test_pickle.py
@@ -45,7 +45,7 @@ class TestPickle(unittest.TestCase):
 
   def test_pickle_jit(self):
     @TinyJit
-    def add(a, b): return a+b+1
+    def add(a, b): return a.sum()+b+1
     for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
     st = pickle.dumps(add)
     del add
@@ -55,7 +55,7 @@ class TestPickle(unittest.TestCase):
     y = Tensor.ones(10, 10).contiguous().realize()
     print("post jit")
     out = add_fxn(x, y)
-    np.testing.assert_equal(out.numpy(), 3)
+    np.testing.assert_equal(out.numpy(), 102)
 
   def test_pickle_schedule(self):
     a = Tensor([1,2])
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index 9b4abeef7b..f572b6575d 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -33,10 +33,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
       for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
       graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
       max_batch_size *= 2
-      if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
+      if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
     except GraphException as e:
       graphed_jit_cache.extend(current_batch)
-      if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
+      if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
     current_batch = []
     current_device = None
 
@@ -128,6 +128,47 @@ class MultiGraphRunner(GraphRunner):  # pylint: disable=abstract-method
     for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
     return list({id(x):x for x in wait_nodes}.values())
 
+ReturnType = TypeVar('ReturnType')
+@dataclass
+class CapturedJit(Generic[ReturnType]):
+  ret: Any  # includes the Tensors or any other returned object
+  jit_cache: List[ExecItem]
+  input_replace: Dict[Tuple[int, int], int]
+  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
+  expected_names: List[Union[int, str]]
+  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
+
+  def __reduce__(self):
+    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
+                            self.expected_names, self.expected_st_vars_dtype_device)
+
+  def __post_init__(self):
+    self._jit_cache: List[ExecItem] = self.jit_cache
+    self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
+    self._graphed = False
+    self._clear_inputs()
+
+  def _clear_inputs(self):
+    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None
+
+  # jit exec
+  def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
+    # assign inputs
+    for idx, offset, device, size, dtype in self.extra_view_inputs:
+      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
+    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]
+
+    # Condense the items into a graph executor.
+    if JIT < 2 and not self._graphed:
+      self._jit_cache = apply_graph_to_jit(self._jit_cache, input_buffers, var_vals)
+      self._input_replace = get_input_replace(self._jit_cache, input_buffers)
+      self._graphed = True
+
+    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
+    for ei in self._jit_cache: ei.run(var_vals, jit=True)
+    self._clear_inputs()
+    return self.ret
+
 def _prepare_jit_inputs(args, kwargs):
   input_tensors: List[Tuple[Union[int, str], Tensor]] = \
     [(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
@@ -142,42 +183,12 @@ def _prepare_jit_inputs(args, kwargs):
   st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
   return input_buffers, var_vals, names, st_vars_dtype_device
 
-ReturnType = TypeVar('ReturnType')
-
-@dataclass(frozen=True)
-class CapturedJit(Generic[ReturnType]):
-  ret: Any  # includes the Tensors or any other returned object
-  jit_cache: List[ExecItem]
-  input_replace: Dict[Tuple[int, int], int]
-  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
-  expected_names: List[Union[int, str]]
-  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]
-  def __post_init__(self): self.clear_jit_inputs()
-
-  def clear_jit_inputs(self):
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None
-
-  def __call__(self, *args, **kwargs) -> ReturnType:
-    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
-    assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
-    assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
-      f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
-
-    # jit exec
-    for idx, offset, device, size, dtype in self.extra_view_inputs:
-      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
-    for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
-    if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-    for ei in self.jit_cache: ei.run(var_vals, jit=True)
-
-    # cleanup
-    self.clear_jit_inputs()
-    return self.ret
-
 class TinyJit(Generic[ReturnType]):
-  def __init__(self, fxn:Callable[..., ReturnType]):
+  def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None):
+    assert fxn or captured, "need either a function or a CapturedJit"
     self.fxn = fxn
-    self.reset()
+    self.captured: Optional[CapturedJit] = captured
+    self.cnt: int = 2 if self.fxn is None else 0
 
   def add_buffer(self, b:Buffer) -> Buffer:
     if found:=self._buffer_replace.get(b, None): return found
@@ -192,32 +203,33 @@ class TinyJit(Generic[ReturnType]):
     self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
 
   def reset(self):
-    self.cnt: int = 0
-    self.captured: Optional[CapturedJit] = None
+    assert self.fxn is not None, "can't reset without function"
+    self.cnt = 0
+    self.captured = None
 
   def __reduce__(self):
     assert self.captured is not None, "can't pickle an uncaptured JIT"
-    return CapturedJit, tuple(self.captured.__dict__.values())
+    return self.__class__, (None, self.captured)
 
   # keep legacy code working
   @property
-  def jit_cache(self) -> List[ExecItem]: return self.captured.jit_cache if self.captured is not None else []
+  def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
   @property
-  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured.input_replace if self.captured is not None else {}
+  def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
 
   def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)  # add support for instance methods
 
   def __call__(self, *args, **kwargs) -> ReturnType:
-    if self.captured is not None: return self.captured(*args, **kwargs)
-
     input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
     if not JIT or self.cnt == 0:
       # jit ignore
+      assert self.fxn is not None
      with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
        ret = self.fxn(*args, **kwargs)
        if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
+      assert self.fxn is not None
       if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
       self._jit_cache: List[ExecItem] = []
       self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
@@ -234,6 +246,7 @@ class TinyJit(Generic[ReturnType]):
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
 
       # track inputs that are views of buffers
+      # TODO: eventually expected_buffers should live in ExecItem
       extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
       for item in jit_cache:
         for b in item.bufs:
@@ -247,14 +260,18 @@ class TinyJit(Generic[ReturnType]):
       assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
       jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
 
-      # Condense the items into a graph executor.
-      if JIT < 2: jit_cache = apply_graph_to_jit(jit_cache, input_buffers, var_vals)
-
       input_replace = get_input_replace(jit_cache, input_buffers)
       if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
 
       # set this for next run
       self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.captured is not None
+      assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
+      assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
+        f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
+      ret = self.captured(input_buffers, var_vals)
     self.cnt += 1
     return ret
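
For reference, a minimal sketch of the round-trip this patch enables, mirroring test_pickle_jit above (the shapes and the 102 result come from that test; nothing here goes beyond the API the diff itself touches):

    import pickle
    from tinygrad import Tensor, TinyJit

    @TinyJit
    def add(a, b): return a.sum()+b+1

    # warm up: call 0 ignores the jit, call 1 captures, call 2+ replays
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))

    # TinyJit.__reduce__ now pickles as (None, captured): the Python function is
    # dropped and the restored TinyJit starts at cnt=2, i.e. replay-only
    add_fxn = pickle.loads(pickle.dumps(add))
    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    out = add_fxn(x, y)
    print(out.numpy())  # every element is 100 (the sum of x) + 1 + 1 = 102

Because graphing now happens lazily inside CapturedJit.__call__ (guarded by self._graphed) instead of at capture time, the unpickled JIT can still build its graph executor on first replay in the new process.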