From e5745c1a0dbe03425b40b16130db21c95cc292d4 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 21 Mar 2024 15:21:55 +0300 Subject: [PATCH] fix nan on multigpus cuda (#3854) --- tinygrad/runtime/graph/cuda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/runtime/graph/cuda.py b/tinygrad/runtime/graph/cuda.py index 9dec8d6fee..f8eda4f34b 100644 --- a/tinygrad/runtime/graph/cuda.py +++ b/tinygrad/runtime/graph/cuda.py @@ -68,7 +68,7 @@ class CUDAGraph: if hasattr(self, 'graph'): check(cuda.cuGraphDestroy(self.graph)) if hasattr(self, 'instance'): check(cuda.cuGraphExecDestroy(self.instance)) - def set_device(self): pass + def set_device(self): check(cuda.cuCtxSetCurrent(self.device.context)) def encode_args_info(self): return (cuda.CUdeviceptr_v2, (1,2,0)) def graph_create(self): return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0))) def graph_instantiate(self, graph):