From caded2f4137a71ab9df896609a6e5f150eb9f1a9 Mon Sep 17 00:00:00 2001
From: b1tg <33436708+b1tg@users.noreply.github.com>
Date: Fri, 16 May 2025 14:03:20 +0800
Subject: [PATCH] llvm diagnostic error (#10267)

* llvm diagnostic info

* use decorator

* better error reporting

* fix mypy

* collect all diag msgs

* test diag error

---------

Co-authored-by: b1tg
Co-authored-by: chenyu
---
 test/test_amd_llvm.py        | 20 ++++++++++++++++++++
 tinygrad/runtime/ops_llvm.py | 11 +++++++++++
 2 files changed, 31 insertions(+)

diff --git a/test/test_amd_llvm.py b/test/test_amd_llvm.py
index 08b70cc5f1..9fb9fa333c 100644
--- a/test/test_amd_llvm.py
+++ b/test/test_amd_llvm.py
@@ -1,6 +1,7 @@
 import unittest
 import numpy as np
 from tinygrad import Device
+from tinygrad.device import CompileError
 from tinygrad.helpers import flat_mv
 if Device.DEFAULT=="AMD":
   from tinygrad.runtime.ops_amd import AMDAllocator, AMDDevice, AMDProgram
@@ -28,5 +29,24 @@ entry:
     allocator._copyout(flat_mv(na.data), a)
     assert na == [0x1234567800000005]
 
+  def test_compiler_diag_error(self):
+    src = """
+@local_temp0 = internal unnamed_addr addrspace(3) global [{N} x float*] undef, align 16
+define amdgpu_kernel void @test(float* noalias align 32 %data0, half* noalias align 32 %data1, float* noalias align 32 %data2) #0
+{{
+  %local_temp0 = addrspacecast [{N} x float*] addrspace(3)* @local_temp0 to [{N} x float*]*
+  %v178 = getelementptr inbounds float, float* %local_temp0, i32 1
+  %v133 = getelementptr inbounds float, float* %data2, i32 1
+  %v134 = load float, float* %v133
+  store float %v134, float* %v178
+  ret void
+}}
+"""
+    compiler = AMDLLVMCompiler("gfx1100")
+    compiler.compile(src.format(N=65536//8))
+    with self.assertRaises(CompileError):
+      # llvm diagnostic: <unknown>:0:0: local memory (65544) exceeds limit (65536) in function 'test'
+      compiler.compile(src.format(N=65536//8+1))
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py
index 9e22f19a1b..1c1cb6ba5c 100644
--- a/tinygrad/runtime/ops_llvm.py
+++ b/tinygrad/runtime/ops_llvm.py
@@ -33,11 +33,21 @@ class LLVMCompiler(Compiler):
     else:
       self.passes = b'default'
 
+    self.diag_msgs: list[str] = []
+    @ctypes.CFUNCTYPE(None, llvm.LLVMDiagnosticInfoRef, ctypes.c_void_p)
+    def handle_diag(diag_ref, _arg):
+      severity = llvm.LLVMGetDiagInfoSeverity(diag_ref)
+      msg = ctypes.string_at(llvm.LLVMGetDiagInfoDescription(diag_ref)).decode()
+      if severity == llvm.LLVMDSError:
+        self.diag_msgs.append(msg)
+    self.handle_diag = handle_diag
+    llvm.LLVMContextSetDiagnosticHandler(llvm.LLVMGetGlobalContext(), handle_diag, None)
     super().__init__(f"compile_llvm_{self.target_arch}{'_jit' if self.jit else ''}{'_opt' if opt else ''}")
 
   def __del__(self): llvm.LLVMDisposePassBuilderOptions(self.pbo)
 
   def compile(self, src:str) -> bytes:
+    self.diag_msgs.clear()
     src_buf = llvm.LLVMCreateMemoryBufferWithMemoryRangeCopy(ctypes.create_string_buffer(src_bytes:=src.encode()), len(src_bytes), b'src')
     mod = expect(llvm.LLVMParseIRInContext(llvm.LLVMGetGlobalContext(), src_buf, ctypes.pointer(m:=llvm.LLVMModuleRef()), err:=cerr()), err, m)
     expect(llvm.LLVMVerifyModule(mod, llvm.LLVMReturnStatusAction, err:=cerr()), err)
@@ -48,6 +58,7 @@ class LLVMCompiler(Compiler):
     llvm.LLVMDisposeModule(mod)
     obj = ctypes.string_at(llvm.LLVMGetBufferStart(obj_buf), llvm.LLVMGetBufferSize(obj_buf))
     llvm.LLVMDisposeMemoryBuffer(obj_buf)
+    if self.diag_msgs: raise RuntimeError("llvm diagnostic: " + "\n".join(self.diag_msgs))
     return jit_loader(obj) if self.jit else obj
 
   def disassemble(self, lib:bytes): capstone_flatdump(lib)