add get_available_backends (#6771)

* lol

* 1 less line lmfao

* something like this?

* comment

* pylint

* just iterator

* backends -> devices

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
geohotstan
2024-09-30 08:58:04 +08:00
committed by GitHub
parent 3c15e64273
commit 282abb4234

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import multiprocessing, decimal, statistics, random
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
from tinygrad.dtype import DType, ImageDType
@@ -27,16 +27,17 @@ class _Device:
return ret
@property
def default(self) -> Compiled: return self[self.DEFAULT]
def get_available_devices(self) -> Iterator[str]:
for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
with contextlib.suppress(Exception): yield self[device].dname
@functools.cached_property
def DEFAULT(self) -> str:
if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
try:
if self[device]:
os.environ[device] = "1" # we set this in environment for spawned children
return device
except Exception: pass
raise RuntimeError("no usable devices")
try:
device = next(self.get_available_devices())
os.environ[device] = "1" # we set this in environment for spawned children
return device
except StopIteration as exc: raise RuntimeError("no usable devices") from exc
Device = _Device()
# **************** Buffer + Allocators ****************