graph remove input buffer references (#4100)

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
andresgit
2024-04-08 23:49:16 +03:00
committed by GitHub
parent 078d841479
commit 7fd12aba85
3 changed files with 9 additions and 0 deletions

View File

@@ -55,6 +55,9 @@ class CUDAGraph(MultiDeviceJITGraph):
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():

View File

@@ -112,6 +112,9 @@ class HSAGraph(MultiDeviceJITGraph):
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# Wait and restore signals
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)

View File

@@ -51,6 +51,9 @@ class MetalGraph:
self.command_buffer: Any = None
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)