graph remove input buffer references (#4100)

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
andresgit
2024-04-08 23:49:16 +03:00
committed by GitHub
parent 078d841479
commit 7fd12aba85
3 changed files with 9 additions and 0 deletions

View File

@@ -55,6 +55,9 @@ class CUDAGraph(MultiDeviceJITGraph):
self.instance = init_c_var(cuda.CUgraphExec(), lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), self.graph, None, None, 0)))
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# Update rawbuffers in the c_args struct.
for (j,i),input_idx in self.input_replace.items():