mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-13 08:05:10 -05:00
graph remove input buffer references (#4100)
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -55,6 +55,9 @@ class CUDAGraph(MultiDeviceJITGraph):
|
||||
|
||||
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# Update rawbuffers in the c_args struct.
|
||||
for (j,i),input_idx in self.input_replace.items():
|
||||
|
||||
@@ -112,6 +112,9 @@ class HSAGraph(MultiDeviceJITGraph):
|
||||
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
|
||||
hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# Wait and restore signals
|
||||
hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
|
||||
@@ -51,6 +51,9 @@ class MetalGraph:
|
||||
self.command_buffer: Any = None
|
||||
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
|
||||
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
|
||||
Reference in New Issue
Block a user