diff --git a/tinygrad/runtime/support/cuda.py b/tinygrad/runtime/support/cuda.py index ed2e112610..42724a334a 100644 --- a/tinygrad/runtime/support/cuda.py +++ b/tinygrad/runtime/support/cuda.py @@ -1,4 +1,4 @@ -import ctypes.util, os, platform +import os, platform from ctypes.util import find_library CUDA_PATH: str | None @@ -10,8 +10,8 @@ if platform.system() == "Windows": def find_nv_dll(glob_pattern): cuda_bin = os.path.join(os.environ.get("CUDA_PATH", ""), "bin") matches = glob.glob(os.path.join(cuda_bin, glob_pattern)) - return matches[0] if matches else None - + return matches[0] if matches else None + CUDA_PATH = find_library('nvcuda') NVRTC_PATH = find_nv_dll("nvrtc64_*.dll") NVJITLINK_PATH = find_nv_dll("nvJitLink_*.dll")