From b1d1816f43e7922932e009ee8c3b16b77358e022 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 13 Sep 2025 23:38:09 +0300 Subject: [PATCH] device: fix envvars (#12159) --- test/unit/test_device.py | 9 +++++++++ tinygrad/device.py | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/test/unit/test_device.py b/test/unit/test_device.py index 8ab43a17f0..1db0595348 100644 --- a/test/unit/test_device.py +++ b/test/unit/test_device.py @@ -64,6 +64,15 @@ class TestDevice(unittest.TestCase): shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"}) else: self.skipTest("only run on CPU/AMD") + def test_compiler_envvar(self): + d = Device[Device.DEFAULT] + dname = Device.DEFAULT.split(':')[0].upper() + assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER" + assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM" + assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM" + assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone + assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name + class MockCompiler(Compiler): def __init__(self, key): super().__init__(key) def compile(self, src) -> bytes: return src.encode() diff --git a/tinygrad/device.py b/tinygrad/device.py index 7a0d8b8ab2..bc0f6eb64c 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -294,7 +294,8 @@ class Compiled: if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}") def _get_compiler_envvar(self, c): - return f"{(devname:=self.device.split(':')[0].upper())}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}" + compiler_name = f"{unwrap_class_type(c).__name__.upper().removesuffix('COMPILER').removeprefix(devname:=self.device.split(':')[0].upper())}" + return f"{devname}_{compiler_name if len(compiler_name) > 0 else unwrap_class_type(c).__name__.upper()}" def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]: for renderer, compiler in compilers: