fix local tests for AMD_LLVM (#13738)

* fix local tests for AMD_LLVM

* fix linters

* skip that for now

* fix segfault
This commit is contained in:
George Hotz
2025-12-17 12:23:46 -04:00
committed by GitHub
parent 7081014c73
commit b013244c38
5 changed files with 15 additions and 15 deletions

View File

@@ -30,8 +30,6 @@ class TestDevice(unittest.TestCase):
@unittest.skipIf(WIN and CI, "skipping windows test") # TODO: subprocess causes memory violation?
def test_env_overwrite_default_compiler(self):
expect_failure = "\ntry: assert Device[Device.DEFAULT].compiler is None;\nexcept Exception: pass"
if Device.DEFAULT == "CPU":
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
@@ -56,10 +54,10 @@ class TestDevice(unittest.TestCase):
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "1"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "0"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, AMDLLVMCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "LLVM"})
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1"})
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "HIP"})
else: self.skipTest("only run on CPU/AMD")
@unittest.skipIf((WIN and CI) or (not Device.DEFAULT == "CPU"), "skipping windows test")

View File

@@ -308,8 +308,8 @@ class Compiled:
# select forced compiler from global env var.
forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else [])
# add forced compilers from individual env vars.
forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1)
# add forced compilers from individual env vars (only if global env var is not set, as it takes precedence).
if not forced_comps: forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1)
if len(forced_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {forced_comps}")
# select remaining compilers (all or forced only)

View File

@@ -266,7 +266,7 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
])
if self.arch.split(":")[0] == "gfx1100":
if self.arch.split(":")[0] in {"gfx1100", "gfx1151"}:
self.extra_matcher += PatternMatcher([
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),

View File

@@ -49,6 +49,8 @@ class LLVMCompiler(Compiler):
else:
self.passes = b'default<O0>'
# Create a per-instance context instead of using the global context to avoid shared state between parallel test processes
self.context = llvm.LLVMContextCreate()
self.diag_msgs: list[str] = []
@llvm.LLVMDiagnosticHandler
def handle_diag(diag_ref, _arg):
@@ -57,15 +59,17 @@ class LLVMCompiler(Compiler):
if severity == llvm.LLVMDSError:
self.diag_msgs.append(msg)
self.handle_diag = handle_diag
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None)
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)
def __del__(self):
llvm.LLVMDisposePassBuilderOptions(self.pbo)
llvm.LLVMContextDispose(self.context)
def compile(self, src:str) -> bytes:
self.diag_msgs.clear()
src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
mod = expect(llvm.LLVMParseIRInContext(self.context, src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err)
expect(llvm.LLVMRunPasses(mod, self.passes, self.target_machine, self.pbo), 'failed to run passes')
if DEBUG >= 7: print(ctypes.string_at(llvm.LLVMPrintModuleToString(mod)).decode())

View File

@@ -246,12 +246,10 @@ def xlog2(d:UOp) -> UOp:
# log2(Inf) = Inf
r = d.ne(math.inf).where(r, r.const_like(math.inf))
# log2(0) = -Inf (handle both +0.0 and -0.0)
r = d.ne(0.0).where(r, r.const_like(-math.inf))
# log2(x) = NaN for x < 0
r = (d<-0.0).where(r.const_like(math.nan), r)
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
# log2_zero = the value of unmasked xlog2(0.0).
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
# log2(NaN) = NaN
r = d.ne(d).where(r.const_like(math.nan), r)
# log2(-0.0) = -Inf. On certain devices like PTX, x == -0.0 won't be true, so the reciprocal is used instead.