mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
add get_available_backends (#6771)
* lol * 1 less line lmfao * something like this? * comment * pylint * just iterator * backends -> devices --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import multiprocessing, decimal, statistics, random
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type
|
||||
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
|
||||
from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, ProfileLogger, PROFILE
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
@@ -27,16 +27,17 @@ class _Device:
|
||||
return ret
|
||||
@property
|
||||
def default(self) -> Compiled: return self[self.DEFAULT]
|
||||
def get_available_devices(self) -> Iterator[str]:
|
||||
for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
|
||||
with contextlib.suppress(Exception): yield self[device].dname
|
||||
@functools.cached_property
|
||||
def DEFAULT(self) -> str:
|
||||
if (from_env:=next((d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1), None)): return from_env
|
||||
for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
|
||||
try:
|
||||
if self[device]:
|
||||
os.environ[device] = "1" # we set this in environment for spawned children
|
||||
return device
|
||||
except Exception: pass
|
||||
raise RuntimeError("no usable devices")
|
||||
try:
|
||||
device = next(self.get_available_devices())
|
||||
os.environ[device] = "1" # we set this in environment for spawned children
|
||||
return device
|
||||
except StopIteration as exc: raise RuntimeError("no usable devices") from exc
|
||||
Device = _Device()
|
||||
|
||||
# **************** Buffer + Allocators ****************
|
||||
|
||||
Reference in New Issue
Block a user