nv default for ampere & ada (#5329)

This commit is contained in:
nimlgen
2024-07-08 19:01:27 +03:00
committed by GitHub
parent 51d6f372e4
commit bb2222e488

View File

@@ -29,7 +29,7 @@ class _Device:
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._devices, None) # type: ignore
if device_from_env: return device_from_env
for device in ["METAL", "AMD", "CUDA", "GPU", "CLANG", "LLVM"]:
for device in ["METAL", "AMD", "NV", "CUDA", "GPU", "CLANG", "LLVM"]:
try:
if self[device]:
os.environ[device] = "1" # we set this in environment for spawned children