support DEV= to specify device (#11351)

This commit is contained in:
chenyu
2025-07-23 17:40:55 -04:00
committed by GitHub
parent 76a2ddbd78
commit 5b570196e4
2 changed files with 21 additions and 2 deletions

View File

@@ -512,6 +512,24 @@ class TestTinygrad(unittest.TestCase):
subprocess.run([f'NPY=1 {Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
if Device.DEFAULT != "CPU":
# setting multiple devices fail
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run([f'{Device.DEFAULT}=1 CPU=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
# setting device via DEV
subprocess.run([f'DEV={Device.DEFAULT.capitalize()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
subprocess.run([f'DEV={Device.DEFAULT.lower()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
subprocess.run([f'DEV={Device.DEFAULT.upper()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
with self.assertRaises(subprocess.CalledProcessError):
subprocess.run([f'DEV={Device.DEFAULT} CPU=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
shell=True, check=True)
def test_no_attributeerror_after_apply_uop_exception(self):
try:
Tensor.arange(4).reshape(3,2)

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Any, Generic, TypeVar, Iterator
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent, dedup
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -37,7 +37,8 @@ class _Device:
with contextlib.suppress(Exception): yield self[device].device
@functools.cached_property
def DEFAULT(self) -> str:
from_env = [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1]
dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
if len(from_env) == 1: return from_env[0]
try: