diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index a66de9fa22..4699e33cc8 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -55,6 +55,9 @@ class CUDAGraph(MultiDeviceJITGraph): self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0))) + # clear jit inputs to allow their memory to be freed/reused + for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: # Update rawbuffers in the c_args struct. for (j,i),input_idx in self.input_replace.items(): diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index 5161b9100e..8ac52f3296 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -112,6 +112,9 @@ class HSAGraph(MultiDeviceJITGraph): for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0) hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0) + # clear jit inputs to allow their memory to be freed/reused + for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: # Wait and restore signals hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index 49ab94c407..40ffe9d70f 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -51,6 +51,9 @@ class MetalGraph: self.command_buffer: Any = None if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i') + # clear jit inputs to allow their memory to be freed/reused + for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None + def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: # NOTE: you at least can't update the ints if this is running if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)