Fix NaN on multi-GPU CUDA (#3854)

This commit is contained in:
nimlgen
2024-03-21 15:21:55 +03:00
committed by GitHub
parent 4e0819e40b
commit e5745c1a0d

View File

@@ -68,7 +68,7 @@ class CUDAGraph:
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
# NOTE(review): the two definitions below look like the removed/added pair of
# this diff hunk (the +/- markers are missing in this view) — confirm against
# the raw patch.
# First version: set_device is a no-op.
def set_device(self): pass
# Second version: passes self.device.context to cuda.cuCtxSetCurrent and
# routes the call's result through check().
def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context))
def encode_args_info(self):
  """Return a 2-tuple: the cuda.CUdeviceptr_v2 type and the index triple (1, 2, 0).

  NOTE(review): the meaning of the (1, 2, 0) triple is defined by the caller
  and is not visible here — confirm before changing it.
  """
  ptr_type = cuda.CUdeviceptr_v2
  arg_indices = (1, 2, 0)
  return (ptr_type, arg_indices)
def graph_create(self):
  """Build a cuda.CUgraph handle via init_c_var and return init_c_var's result.

  The initializer callback hands a reference to the handle to
  cuda.cuGraphCreate (with flags argument 0) and passes the call's result
  through check().
  """
  def _create(handle):
    # Keep returning check()'s value, exactly as the original lambda did.
    return check(cuda.cuGraphCreate(ctypes.byref(handle), 0))
  return init_c_var(cuda.CUgraph(), _create)
def graph_instantiate(self, graph):