device: fix envvars (#12159)

This commit is contained in:
nimlgen
2025-09-13 23:38:09 +03:00
committed by GitHub
parent 19d9d29b7e
commit b1d1816f43
2 changed files with 11 additions and 1 deletions

View File

@@ -64,6 +64,15 @@ class TestDevice(unittest.TestCase):
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
else: self.skipTest("only run on CPU/AMD")
def test_compiler_envvar(self):
d = Device[Device.DEFAULT]
dname = Device.DEFAULT.split(':')[0].upper()
assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER"
assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM"
assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM"
assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone
assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name
class MockCompiler(Compiler):
def __init__(self, key): super().__init__(key)
def compile(self, src) -> bytes: return src.encode()

View File

@@ -294,7 +294,8 @@ class Compiled:
if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}")
def _get_compiler_envvar(self, c):
return f"{(devname:=self.device.split(':')[0].upper())}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}"
compiler_name = f"{unwrap_class_type(c).__name__.upper().removesuffix('COMPILER').removeprefix(devname:=self.device.split(':')[0].upper())}"
return f"{devname}_{compiler_name if len(compiler_name) > 0 else unwrap_class_type(c).__name__.upper()}"
def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]:
for renderer, compiler in compilers: