From b013244c3894db246292a75e06616445afcc6b5d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 17 Dec 2025 12:23:46 -0400 Subject: [PATCH] fix local tests for AMD_LLVM (#13738) * fix local tests for AMD_LLVM * fix linters * skip that for now * fix segfault --- test/unit/test_device.py | 8 +++----- tinygrad/device.py | 4 ++-- tinygrad/renderer/llvmir.py | 2 +- tinygrad/runtime/support/compiler_cpu.py | 10 +++++++--- tinygrad/uop/decompositions.py | 6 ++---- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/test/unit/test_device.py b/test/unit/test_device.py index becc09fec0..cedbdf0d43 100644 --- a/test/unit/test_device.py +++ b/test/unit/test_device.py @@ -30,8 +30,6 @@ class TestDevice(unittest.TestCase): @unittest.skipIf(WIN and CI, "skipping windows test") # TODO: subproccess causes memory violation? def test_env_overwrite_default_compiler(self): - expect_failure = "\ntry: assert Device[Device.DEFAULT].compiler is None;\nexcept Exception: pass" - if Device.DEFAULT == "CPU": from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler try: _, _ = CPULLVMCompiler(), ClangJITCompiler() @@ -56,10 +54,10 @@ class TestDevice(unittest.TestCase): shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "1"}) subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'], shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "0"}) + subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, AMDLLVMCompiler)"'], + shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "LLVM"}) subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'], - shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1"}) - subprocess.run([f'python3 -c "{imports}; {expect_failure}"'], - shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"}) + shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "HIP"}) else: self.skipTest("only run on CPU/AMD") @unittest.skipIf((WIN and CI) or (not Device.DEFAULT == "CPU"), "skipping windows test") diff --git a/tinygrad/device.py b/tinygrad/device.py index 455f6e8b3f..f8e5b5e915 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -308,8 +308,8 @@ class Compiled: # select forced compiler from global env var. forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else []) - # add forced compilers from individual env vars. - forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1) + # add forced compilers from individual env vars (only if global env var is not set, as it takes precedence). + if not forced_comps: forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1) if len(forced_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {forced_comps}") # select remaining compilers (all or forced only) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index d468a3055f..9b3554f991 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -266,7 +266,7 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64), x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None), ]) - if self.arch.split(":")[0] == "gfx1100": + if self.arch.split(":")[0] in {"gfx1100", "gfx1151"}: self.extra_matcher += PatternMatcher([ (UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)), lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))), diff --git a/tinygrad/runtime/support/compiler_cpu.py b/tinygrad/runtime/support/compiler_cpu.py index 553706f3e1..cffd69254c 100644 --- a/tinygrad/runtime/support/compiler_cpu.py +++ b/tinygrad/runtime/support/compiler_cpu.py @@ -49,6 +49,8 @@ class LLVMCompiler(Compiler): else: self.passes = b'default' + # Create a per-instance context instead of using the global context to avoid shared state between parallel test processes + self.context = llvm.LLVMContextCreate() self.diag_msgs: list[str] = [] @llvm.LLVMDiagnosticHandler def handle_diag(diag_ref, _arg): @@ -57,15 +59,17 @@ class LLVMCompiler(Compiler): if severity == llvm.LLVMDSError: self.diag_msgs.append(msg) self.handle_diag = handle_diag - llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None) + llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None) super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}") - def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo) + def __del__(self): + llvm.LLVMDisposePassBuilderOptions(self.pbo) + llvm.LLVMContextDispose(self.context) def compile(self, src:str) -> bytes: self.diag_msgs.clear() src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src') - mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m) + mod = expect(llvm.LLVMParseIRInContext(self.context, src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m) expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err) expect(llvm.LLVMRunPasses(mod, self.passes, self.target_machine, self.pbo), 'failed to run passes') if DEBUG >= 7: print(ctypes.string_at(llvm.LLVMPrintModuleToString(mod)).decode()) diff --git a/tinygrad/uop/decompositions.py b/tinygrad/uop/decompositions.py index 54985c02be..82dfe67316 100644 --- a/tinygrad/uop/decompositions.py +++ b/tinygrad/uop/decompositions.py @@ -246,12 +246,10 @@ def xlog2(d:UOp) -> UOp: # log2(Inf) = Inf r = d.ne(math.inf).where(r, r.const_like(math.inf)) + # log2(0) = -Inf (handle both +0.0 and -0.0) + r = d.ne(0.0).where(r, r.const_like(-math.inf)) # log2(x) = NaN for x < 0 r = (d<-0.0).where(r.const_like(math.nan), r) - # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true. - # log2_zero = the value of unmasked xlog2(0.0). - log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()] - r = r.ne(log2_zero).where(r, r.const_like(-math.inf)) # log2(NaN) = NaN r = d.ne(d).where(r.const_like(math.nan), r) # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.