diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 9dec8d6fee..f8eda4f34b 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -68,7 +68,7 @@ class CUDAGraph: if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph)) if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance)) - def set_device(self): pass + def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context)) def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0)) def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) def graph_instantiate(self, graph):