mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-10 22:54:59 -05:00
fix nan on multigpus cuda (#3854)
This commit is contained in:
@@ -68,7 +68,7 @@ class CUDAGraph:
|
||||
if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph))
|
||||
if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance))
|
||||
|
||||
def set_device(self): pass
|
||||
def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context))
|
||||
def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0))
|
||||
def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))
|
||||
def graph_instantiate(self, graph):
|
||||
|
||||
Reference in New Issue
Block a user