diff --git a/tinygrad/device.py b/tinygrad/device.py index a6bd2ea8c2..5449429f10 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -2,7 +2,7 @@ from __future__ import annotations import multiprocessing, decimal, statistics, random from dataclasses import dataclass, replace from collections import defaultdict -from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type +from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE from tinygrad.dtype import DType, ImageDType @@ -27,16 +27,17 @@ class _Device: return ret @property def default(self) -> Compiled: return self[self.DEFAULT] + def get_available_devices(self) -> Iterator[str]: + for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]: + with contextlib.suppress(Exception): yield self[device].dname @functools.cached_property def DEFAULT(self) -> str: if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env - for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]: - try: - if self[device]: - os.environ[device] = "1" # we set this in environment for spawned children - return device - except Exception: pass - raise RuntimeError("no usable devices") + try: + device = next(self.get_available_devices()) + os.environ[device] = "1" # we set this in environment for spawned children + return device + except StopIteration as exc: raise RuntimeError("no usable devices") from exc Device = _Device() # **************** Buffer + Allocators ****************