mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix local tests for AMD_LLVM (#13738)
* fix local tests for AMD_LLVM * fix linters * skip that for now * fix segfault
This commit is contained in:
@@ -30,8 +30,6 @@ class TestDevice(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(WIN and CI, "skipping windows test") # TODO: subproccess causes memory violation?
|
||||
def test_env_overwrite_default_compiler(self):
|
||||
expect_failure = "\ntry: assert Device[Device.DEFAULT].compiler is None;\nexcept Exception: pass"
|
||||
|
||||
if Device.DEFAULT == "CPU":
|
||||
from tinygrad.runtime.support.compiler_cpu import CPULLVMCompiler, ClangJITCompiler
|
||||
try: _, _ = CPULLVMCompiler(), ClangJITCompiler()
|
||||
@@ -56,10 +54,10 @@ class TestDevice(unittest.TestCase):
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "1"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_LLVM": "0"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, AMDLLVMCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "LLVM"})
|
||||
subprocess.run([f'python3 -c "{imports}; assert isinstance(Device[Device.DEFAULT].compiler, HIPCompiler)"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1"})
|
||||
subprocess.run([f'python3 -c "{imports}; {expect_failure}"'],
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_HIP": "1", "AMD_LLVM": "1"})
|
||||
shell=True, check=True, env={**os.environ, "DEV": "AMD", "AMD_CC": "HIP"})
|
||||
else: self.skipTest("only run on CPU/AMD")
|
||||
|
||||
@unittest.skipIf((WIN and CI) or (not Device.DEFAULT == "CPU"), "skipping windows test")
|
||||
|
||||
@@ -308,8 +308,8 @@ class Compiled:
|
||||
# select forced compiler from global env var.
|
||||
forced_comps = set([self.comp_sets[val][1]] if self.comps_ctrl_var is not None and (val:=self.comps_ctrl_var.value) else [])
|
||||
|
||||
# add forced compilers from individual env vars.
|
||||
forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1)
|
||||
# add forced compilers from individual env vars (only if global env var is not set, as it takes precedence).
|
||||
if not forced_comps: forced_comps |= set(rc for en, rc in self.comp_sets.values() if en is not None and en.value == 1)
|
||||
if len(forced_comps) > 1: raise RuntimeError(f"{self.device}: multiple compilers set in env {forced_comps}")
|
||||
|
||||
# select remaining compilers (all or forced only)
|
||||
|
||||
@@ -266,7 +266,7 @@ exit: %packed = phi i32 [%packed_bf8, %do_bf8], [%packed_fp8, %do_fp8]\n %trunc
|
||||
lambda x: UOp(Ops.WMMA, dtypes.float.vec(4), (x.src[0].bitcast(dtypes.uint64), x.src[1].bitcast(dtypes.uint64),
|
||||
x.src[2]), (*x.arg,)) if x.src[0].dtype in (dtypes.fp8e4m3.vec(8), dtypes.fp8e5m2.vec(8)) else None),
|
||||
])
|
||||
if self.arch.split(":")[0] == "gfx1100":
|
||||
if self.arch.split(":")[0] in {"gfx1100", "gfx1151"}:
|
||||
self.extra_matcher += PatternMatcher([
|
||||
(UPat(Ops.WMMA, name="x", dtype=dtypes.half.vec(8)),
|
||||
lambda x: UOp(Ops.WMMA, dtypes.half.vec(16), (x.src[0], x.src[1], x.src[2].cast(dtypes.half.vec(16))), (*x.arg,)).cast(dtypes.half.vec(8))),
|
||||
|
||||
@@ -49,6 +49,8 @@ class LLVMCompiler(Compiler):
|
||||
else:
|
||||
self.passes = b'default<O0>'
|
||||
|
||||
# Create a per-instance context instead of using the global context to avoid shared state between parallel test processes
|
||||
self.context = llvm.LLVMContextCreate()
|
||||
self.diag_msgs: list[str] = []
|
||||
@llvm.LLVMDiagnosticHandler
|
||||
def handle_diag(diag_ref, _arg):
|
||||
@@ -57,15 +59,17 @@ class LLVMCompiler(Compiler):
|
||||
if severity == llvm.LLVMDSError:
|
||||
self.diag_msgs.append(msg)
|
||||
self.handle_diag = handle_diag
|
||||
llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
|
||||
llvm.LLVMContextSetDiagnosticHandler(self.context, handle_diag, None)
|
||||
super().__init__(f"compile_llvm_{processor}_{feats}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
|
||||
|
||||
def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)
|
||||
def __del__(self):
|
||||
llvm.LLVMDisposePassBuilderOptions(self.pbo)
|
||||
llvm.LLVMContextDispose(self.context)
|
||||
|
||||
def compile(self, src:str) -> bytes:
|
||||
self.diag_msgs.clear()
|
||||
src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
|
||||
mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
|
||||
mod = expect(llvm.LLVMParseIRInContext(self.context, src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
|
||||
expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err)
|
||||
expect(llvm.LLVMRunPasses(mod, self.passes, self.target_machine, self.pbo), 'failed to run passes')
|
||||
if DEBUG >= 7: print(ctypes.string_at(llvm.LLVMPrintModuleToString(mod)).decode())
|
||||
|
||||
@@ -246,12 +246,10 @@ def xlog2(d:UOp) -> UOp:
|
||||
|
||||
# log2(Inf) = Inf
|
||||
r = d.ne(math.inf).where(r, r.const_like(math.inf))
|
||||
# log2(0) = -Inf (handle both +0.0 and -0.0)
|
||||
r = d.ne(0.0).where(r, r.const_like(-math.inf))
|
||||
# log2(x) = NaN for x < 0
|
||||
r = (d<-0.0).where(r.const_like(math.nan), r)
|
||||
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
|
||||
# log2_zero = the value of unmasked xlog2(0.0).
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
|
||||
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
|
||||
# log2(NaN) = NaN
|
||||
r = d.ne(d).where(r.const_like(math.nan), r)
|
||||
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
|
||||
|
||||
Reference in New Issue
Block a user