device: respect compiler ContextVars (#13523)

* device: envvars for cc

* fix

* fix

* x

* um

* fix

* remote

* em

* cleanup

* typing

* fix

* debug

* lvp?

* ugh

* singl

* rm

* lol

* fix

* ?

* this?

* why?

* rev

* mod test

* l
This commit is contained in:
nimlgen
2025-12-02 14:42:04 +03:00
committed by GitHub
parent 1b7dbfb37f
commit 77a76d1b13
18 changed files with 128 additions and 88 deletions

View File

@@ -42,12 +42,10 @@ class TestDevice(unittest.TestCase):
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "1"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0", "CPU_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "1", "CPU_LLVM": "1"})
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CC": "LLVM"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CC": "CLANGJIT"})
elif Device.DEFAULT == "AMD":
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch)
@@ -64,14 +62,20 @@ class TestDevice(unittest.TestCase):
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
else: self.skipTest("only run on CPU/AMD")
def test_compiler_envvar(self):
d = Device[Device.DEFAULT]
dname = Device.DEFAULT.split(':')[0].upper()
assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER"
assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM"
assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM"
assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone
assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name
@unittest.skipIf((WIN and CI) or (not Device.DEFAULT == "CPU"), "skipping windows test")
def test_env_online(self):
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
with Context(CPU_LLVM=1):
inst = Device["CPU"].compiler
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
with Context(CPU_LLVM=0):
self.assertIsInstance(Device["CPU"].compiler, ClangJITCompiler)
with Context(CPU_LLVM=1):
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
assert inst is Device["CPU"].compiler # cached
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)