mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
device: respect compiler ContextVars (#13523)
* device: envvars for cc * fix * fix * x * um * fix * remote * em * cleanup * typing * fix * debug * lvp? * ugh * singl * rm * lol * fix * ? * this? * why? * rev * mod test * l
This commit is contained in:
@@ -42,12 +42,10 @@ class TestDevice(unittest.TestCase):
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "1"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_LLVM": "0"})
|
||||
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0", "CPU_LLVM": "0"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, CPULLVMCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "0"})
|
||||
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CLANGJIT": "1", "CPU_LLVM": "1"})
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CC": "LLVM"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, ClangJITCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "CPU", "CPU_CC": "CLANGJIT"})
|
||||
elif Device.DEFAULT == "AMD":
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler, AMDLLVMCompiler
|
||||
try: _, _ = HIPCompiler(Device[Device.DEFAULT].arch), AMDLLVMCompiler(Device[Device.DEFAULT].arch)
|
||||
@@ -64,14 +62,20 @@ class TestDevice(unittest.TestCase):
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
|
||||
else: self.skipTest("only run on CPU/AMD")
|
||||
|
||||
def test_compiler_envvar(self):
|
||||
d = Device[Device.DEFAULT]
|
||||
dname = Device.DEFAULT.split(':')[0].upper()
|
||||
assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER"
|
||||
assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM"
|
||||
assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM"
|
||||
assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone
|
||||
assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name
|
||||
@unittest.skipIf((WIN and CI) or (not Device.DEFAULT == "CPU"), "skipping windows test")
|
||||
def test_env_online(self):
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
|
||||
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
|
||||
except Exception as e: self.skipTest(f"skipping compiler test: not all compilers: {e}")
|
||||
|
||||
with Context(CPU_LLVM=1):
|
||||
inst = Device["CPU"].compiler
|
||||
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
|
||||
with Context(CPU_LLVM=0):
|
||||
self.assertIsInstance(Device["CPU"].compiler, ClangJITCompiler)
|
||||
with Context(CPU_LLVM=1):
|
||||
self.assertIsInstance(Device["CPU"].compiler, CPULLVMCompiler)
|
||||
assert inst is Device["CPU"].compiler # cached
|
||||
|
||||
class MockCompiler(Compiler):
|
||||
def __init__(self, key): super().__init__(key)
|
||||
|
||||
Reference in New Issue
Block a user