mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
support DEV= to specify device (#11351)
This commit is contained in:
@@ -512,6 +512,24 @@ class TestTinygrad(unittest.TestCase):
|
||||
subprocess.run([f'NPY=1 {Device.DEFAULT}=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
|
||||
if Device.DEFAULT != "CPU":
|
||||
# setting multiple devices fail
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run([f'{Device.DEFAULT}=1 CPU=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
|
||||
# setting device via DEV
|
||||
subprocess.run([f'DEV={Device.DEFAULT.capitalize()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
subprocess.run([f'DEV={Device.DEFAULT.lower()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
subprocess.run([f'DEV={Device.DEFAULT.upper()} python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
subprocess.run([f'DEV={Device.DEFAULT} CPU=1 python3 -c "from tinygrad import Device; assert Device.DEFAULT == \\"{Device.DEFAULT}\\""'],
|
||||
shell=True, check=True)
|
||||
|
||||
def test_no_attributeerror_after_apply_uop_exception(self):
|
||||
try:
|
||||
Tensor.arange(4).reshape(3,2)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from typing import Any, Generic, TypeVar, Iterator
|
||||
import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, \
|
||||
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent
|
||||
colored, Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, cpu_events, ProfileEvent, dedup
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -37,7 +37,8 @@ class _Device:
|
||||
with contextlib.suppress(Exception): yield self[device].device
|
||||
@functools.cached_property
|
||||
def DEFAULT(self) -> str:
|
||||
from_env = [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1]
|
||||
dev = [dev] if (dev:=getenv("DEV", "").upper()) else []
|
||||
from_env = dedup(dev + [d for d in self._devices if d not in ["DISK", "NPY"] and getenv(d) == 1])
|
||||
assert len(from_env) < 2, f"multiple devices set in env: {from_env}"
|
||||
if len(from_env) == 1: return from_env[0]
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user