device: fix envvars (#12159)

This commit is contained in:
nimlgen
2025-09-13 23:38:09 +03:00
committed by GitHub
parent 19d9d29b7e
commit b1d1816f43
2 changed files with 11 additions and 1 deletions

View File

@@ -64,6 +64,15 @@ class TestDevice(unittest.TestCase):
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
else: self.skipTest("only run on CPU/AMD")
def test_compiler_envvar(self):
d = Device[Device.DEFAULT]
dname = Device.DEFAULT.split(':')[0].upper()
assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER"
assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM"
assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM"
assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone
assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()