mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
device: fix envvars (#12159)
This commit is contained in:
@@ -64,6 +64,15 @@ class TestDevice(unittest.TestCase):
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
|
||||
else: self.skipTest("only run on CPU/AMD")
|
||||
|
||||
def test_compiler_envvar(self):
|
||||
d = Device[Device.DEFAULT]
|
||||
dname = Device.DEFAULT.split(':')[0].upper()
|
||||
assert d._get_compiler_envvar(type("Compiler", (), {})) == f"{dname}_COMPILER"
|
||||
assert d._get_compiler_envvar(type("LLVMCompiler", (), {})) == f"{dname}_LLVM"
|
||||
assert d._get_compiler_envvar(type("RandomCompiler", (), {})) == f"{dname}_RANDOM"
|
||||
assert d._get_compiler_envvar(type(f"{dname}Compiler", (), {})) == f"{dname}_{dname}COMPILER" # do not repeat device name alone
|
||||
assert d._get_compiler_envvar(type(f"{dname}LLVMCompiler", (), {})) == f"{dname}_LLVM" # do not repeat device name
|
||||
|
||||
class MockCompiler(Compiler):
|
||||
def __init__(self, key): super().__init__(key)
|
||||
def compile(self, src) -> bytes: return src.encode()
|
||||
|
||||
@@ -294,7 +294,8 @@ class Compiled:
|
||||
if DEBUG >= 1: print(f"{self.device}: using {self.compiler.__class__.__name__}")
|
||||
|
||||
def _get_compiler_envvar(self, c):
|
||||
return f"{(devname:=self.device.split(':')[0].upper())}_{unwrap_class_type(c).__name__.removesuffix('Compiler').removeprefix(devname).upper()}"
|
||||
compiler_name = f"{unwrap_class_type(c).__name__.upper().removesuffix('COMPILER').removeprefix(devname:=self.device.split(':')[0].upper())}"
|
||||
return f"{devname}_{compiler_name if len(compiler_name) > 0 else unwrap_class_type(c).__name__.upper()}"
|
||||
|
||||
def _get_available_compilers(self, compilers) -> Iterator[tuple[Renderer, Compiler]]:
|
||||
for renderer, compiler in compilers:
|
||||
|
||||
Reference in New Issue
Block a user